diff --git a/tilelang/original/docs/.gitignore b/tilelang/original/docs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4d8eb40499da61a00de91503a87038940f8a95d6 --- /dev/null +++ b/tilelang/original/docs/.gitignore @@ -0,0 +1,2 @@ +_build/ +autoapi/ \ No newline at end of file diff --git a/tilelang/original/docs/CNAME b/tilelang/original/docs/CNAME new file mode 100644 index 0000000000000000000000000000000000000000..ca903c694a195b577524d38b2b26cc577ab76bf9 --- /dev/null +++ b/tilelang/original/docs/CNAME @@ -0,0 +1 @@ +tilelang.com \ No newline at end of file diff --git a/tilelang/original/docs/Makefile b/tilelang/original/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..157adfb90438fa9f3bf061f876877d606acc4d0d --- /dev/null +++ b/tilelang/original/docs/Makefile @@ -0,0 +1,25 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= python -m sphinx +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile clean + +# The "clean" target is updated to remove the autoapi generated files as well. +# Run "make clean" to ensure a completely fresh build. +clean: + rm -rf $(BUILDDIR) autoapi + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/tilelang/original/docs/README.md b/tilelang/original/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..349c0eccc5e7e030456d4712241ae1d72282ffa6 --- /dev/null +++ b/tilelang/original/docs/README.md @@ -0,0 +1,30 @@ +# Tile Language Documentation + +The documentation was built upon [Sphinx](https://www.sphinx-doc.org/en/master/). + +## Dependencies + +Run the following command in this directory to install dependencies first: + +```bash +pip3 install -r requirements.txt +``` + +## Build the Documentation + +Then you can build the documentation by running: + +```bash +make html +``` + +## View the Documentation + +Run the following command to start a simple HTTP server: + +```bash +cd _build/html +python3 -m http.server +``` + +Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending ` -p PORT_NUMBER` in the python command above). 
diff --git a/tilelang/original/docs/_static/custom.css b/tilelang/original/docs/_static/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..0ef6b48cb8b08d17ac582728af5a6f040de06ee1 --- /dev/null +++ b/tilelang/original/docs/_static/custom.css @@ -0,0 +1,11 @@ +/* Reduce the displayed size of the sidebar logo in Furo */ +.sidebar-logo { + max-height: 125px; + width: auto; +} + +/* Optional: keep container from growing too tall due to spacing */ +.sidebar-logo-container { + line-height: 0; +} + diff --git a/tilelang/original/docs/_static/img/LayoutInference.png b/tilelang/original/docs/_static/img/LayoutInference.png new file mode 100644 index 0000000000000000000000000000000000000000..d44e4100d013329365036f355d32f400084456d9 Binary files /dev/null and b/tilelang/original/docs/_static/img/LayoutInference.png differ diff --git a/tilelang/original/docs/_static/img/MatmulExample.png b/tilelang/original/docs/_static/img/MatmulExample.png new file mode 100644 index 0000000000000000000000000000000000000000..555ae30a75b2486bffb8acf27f72802d2c96ec3d Binary files /dev/null and b/tilelang/original/docs/_static/img/MatmulExample.png differ diff --git a/tilelang/original/docs/_static/img/Parallel.png b/tilelang/original/docs/_static/img/Parallel.png new file mode 100644 index 0000000000000000000000000000000000000000..656d4cc01089ccc374d8890c431ee7d9ae096fb5 Binary files /dev/null and b/tilelang/original/docs/_static/img/Parallel.png differ diff --git a/tilelang/original/docs/_static/img/ir_transform_diagram.png b/tilelang/original/docs/_static/img/ir_transform_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..3bd86891394c90db12f98c5dcc43f02330aa3f93 Binary files /dev/null and b/tilelang/original/docs/_static/img/ir_transform_diagram.png differ diff --git a/tilelang/original/docs/_static/img/logo-row.svg b/tilelang/original/docs/_static/img/logo-row.svg new file mode 100644 index 0000000000000000000000000000000000000000..633243f3a9a003a903b859e8d8da5273b0f4cbf3 --- /dev/null +++ b/tilelang/original/docs/_static/img/logo-row.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tilelang/original/docs/_static/img/logo-v2.png b/tilelang/original/docs/_static/img/logo-v2.png new file mode 100644 index 0000000000000000000000000000000000000000..410773f60a0d6ddf9bb86186ecb70529ff1d4667 Binary files /dev/null and b/tilelang/original/docs/_static/img/logo-v2.png differ diff --git a/tilelang/original/docs/_static/img/logo.png b/tilelang/original/docs/_static/img/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..5d04697ce4cd98d1aa6d9edc0f492601ef92c575 Binary files /dev/null and b/tilelang/original/docs/_static/img/logo.png differ diff --git a/tilelang/original/docs/_static/img/mla_hopper/bs128_float16.png b/tilelang/original/docs/_static/img/mla_hopper/bs128_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..3cf24c84b82532bf422efee26afe61b4ae0e1948 Binary files /dev/null and b/tilelang/original/docs/_static/img/mla_hopper/bs128_float16.png differ diff --git a/tilelang/original/docs/_static/img/mla_hopper/bs64_float16.png b/tilelang/original/docs/_static/img/mla_hopper/bs64_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..15807c3d2e57f5a2848b792d0fe746db31be455d Binary files /dev/null and b/tilelang/original/docs/_static/img/mla_hopper/bs64_float16.png differ diff --git a/tilelang/original/docs/_static/img/mla_hopper/pv_layout.jpg 
b/tilelang/original/docs/_static/img/mla_hopper/pv_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..79b0c8cf301d9c04eef050c893156c71549ce03d Binary files /dev/null and b/tilelang/original/docs/_static/img/mla_hopper/pv_layout.jpg differ diff --git a/tilelang/original/docs/_static/img/mla_hopper/qk_layout.jpg b/tilelang/original/docs/_static/img/mla_hopper/qk_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d5bd923d0d8ab1fe5edece222f31777ccd0d746 Binary files /dev/null and b/tilelang/original/docs/_static/img/mla_hopper/qk_layout.jpg differ diff --git a/tilelang/original/docs/_static/img/op_benchmark_consistent_gemm_fp16.png b/tilelang/original/docs/_static/img/op_benchmark_consistent_gemm_fp16.png new file mode 100644 index 0000000000000000000000000000000000000000..840e423e7199a96e8127cfe2750f7ebb60058bb3 Binary files /dev/null and b/tilelang/original/docs/_static/img/op_benchmark_consistent_gemm_fp16.png differ diff --git a/tilelang/original/docs/_static/img/overview.png b/tilelang/original/docs/_static/img/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..0aa477701b3dbb8eac60988c5e46dbafc8acf8aa Binary files /dev/null and b/tilelang/original/docs/_static/img/overview.png differ diff --git a/tilelang/original/docs/_static/img/software_pipeline_inference.png b/tilelang/original/docs/_static/img/software_pipeline_inference.png new file mode 100644 index 0000000000000000000000000000000000000000..b1b3fd667eb612ea01cafd14c16ecdd599d42c02 Binary files /dev/null and b/tilelang/original/docs/_static/img/software_pipeline_inference.png differ diff --git a/tilelang/original/docs/_static/img/sparse_mma_storage_example.png b/tilelang/original/docs/_static/img/sparse_mma_storage_example.png new file mode 100644 index 0000000000000000000000000000000000000000..0b16398197b28f10b5681cfd27a9f2fe061be5cd Binary files /dev/null and b/tilelang/original/docs/_static/img/sparse_mma_storage_example.png differ diff --git a/tilelang/original/docs/compiler_internals/inject_fence_proxy.md b/tilelang/original/docs/compiler_internals/inject_fence_proxy.md new file mode 100644 index 0000000000000000000000000000000000000000..7a89456ac809d6d534cce9f6167a4e4ba52a9c59 --- /dev/null +++ b/tilelang/original/docs/compiler_internals/inject_fence_proxy.md @@ -0,0 +1,113 @@ +# InjectFenceProxy Pass + +`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations. + +## Why Fences Are Needed + +Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behavior. + +## What the Pass Does + +- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements reset the state, such as an explicit fence). +- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization. +- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier. 
+
+The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
+
+### Timeline View
+
+```
+generic initialize_wgmma_descriptor → generic shared-store → async wgmma
+                │                            │                    │
+         generic proxy                generic proxy          async proxy
+                                             │ fence inserted here ↑
+                                             └─────────────────────┘
+```
+
+The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the shared-memory store and the `wgmma` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
+
+## Coverage of Intrinsics
+
+The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
+
+## Usage
+
+The pass is part of the default TileLang lowering pipeline. To apply it manually:
+
+```python
+import tvm
+from tvm import IRModule
+
+from tilelang import tl
+
+mod = IRModule({"main": prim_func})
+with tvm.transform.PassContext():
+    mod = tl.transform.InjectFenceProxy()(mod)
+```
+
+## End-to-End Example
+
+Before the pass:
+
+```python
+@T.prim_func
+def kernel():
+    with T.Kernel(1):
+        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+        smem = T.decl_buffer((128,), "float16", scope="shared")
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
+        smem[0] = T.float16(0)
+        T.ptx_wgmma_ss(
+            "float16",
+            "m64n64k16",
+            T.bool(True),
+            T.bool(True),
+            "fp16",
+            "fp16",
+            "fp16",
+            desc.data,
+            T.int32(0),
+            desc.data,
+            T.int32(0),
+            smem.data,
+            T.int32(0),
+            T.bool(True),
+            1,
+            1,
+        )
+```
+
+After `tl.transform.InjectFenceProxy`:
+
+```python
+@T.prim_func
+def kernel():
+    with T.Kernel(1):
+        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+        smem = T.decl_buffer((128,), "float16", scope="shared")
+        T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
+        smem[0] = T.float16(0)
+        T.fence_proxy_async()
+        T.ptx_wgmma_ss(
+            "float16",
+            "m64n64k16",
+            T.bool(True),
+            T.bool(True),
+            "fp16",
+            "fp16",
+            "fp16",
+            desc.data,
+            T.int32(0),
+            desc.data,
+            T.int32(0),
+            smem.data,
+            T.int32(0),
+            T.bool(True),
+            1,
+            1,
+        )
+```
+
+The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
+
+## Extending the Pass
+
+If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.
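+
+To sanity-check that the fence survives into the final kernel, one lightweight option is to inspect the generated source after compilation. This is only a sketch: it assumes a Hopper-targeted `program` (a `T.prim_func` that mixes generic stores with async `wgmma`/TMA operations) compiled the same way as in the other TileLang docs.
+
+```python
+import tilelang
+
+# `program` is assumed to be a Hopper-targeted T.prim_func defined elsewhere.
+kernel = tilelang.compile(program, target="cuda")
+# Inspect the generated CUDA source for the injected barrier; it is typically
+# emitted as a `fence.proxy.async` sequence right before the async proxy op.
+print(kernel.get_kernel_source())
+```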
diff --git a/tilelang/original/docs/compiler_internals/letstmt_inline.md b/tilelang/original/docs/compiler_internals/letstmt_inline.md
new file mode 100644
index 0000000000000000000000000000000000000000..012af9020d0e228959588cbbdb704ccd5ba34cda
--- /dev/null
+++ b/tilelang/original/docs/compiler_internals/letstmt_inline.md
@@ -0,0 +1,163 @@
+# LetStmt Inlining in TileLang
+
+This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.
+
+## Overview
+
+A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.
+
+## When Does LetStmt Get Inlined?
+
+The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:
+
+### 1. The value satisfies `CanInlineLetStmt`
+
+The `CanInlineLetStmt` helper returns `true` when:
+
+- **The value is a constant** (`is_const_number(op->value)` returns true)
+- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
+- **The value is an integer expression without side effects**:
+  - The value has `int` dtype
+  - The side effect level is `kPure` or lower (no observable side effects)
+
+```cpp
+bool CanInlineLetStmt(const LetStmtNode *op) {
+  if (is_const_number(op->value))
+    return true;
+  if (op->value.as<VarNode>())
+    return true;
+  // Won't face the deep expression explosion problem as in Let expression.
+  // attempt to inline as much as possible if the value integer type(can be
+  // index).
+  if (!op->value.dtype().is_int())
+    return false;
+  return SideEffect(op->value) <= CallEffectKind::kPure;
+}
+```
+
+### 2. The variable is NOT used in buffer definitions
+
+Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).
+
+This protection exists because:
+- Buffer definitions are not updated during the simplification pass
+- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
+- This would cause compilation errors or incorrect behavior
+
+The mutator checks this before dropping the binding:
+
+```cpp
+bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
+
+if (can_inline && !used_in_buffer_def) {
+  return body; // Inline: remove LetStmt and return body directly
+}
+```
+
+## Example: Why Buffer Definition Variables Are Protected
+
+Consider this code:
+
+```python
+let stride = M * 16
+let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1])
+buffer_a[i, j] = ...
+```
+
+- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
+- However, `stride` is used in `buffer_a`'s `strides` field
+- If we inline it, the buffer definition becomes `strides=[M*16, 1]`
+- But the Buffer object's fields are not updated during simplification
+- Later code accessing `buffer_a` would fail to find the `stride` variable
+
+Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.
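+
+To see which bindings actually survive, you can run the standalone eager pass by hand and print the result. This is a minimal sketch, assuming `mod` is an `IRModule` wrapping the `PrimFunc` you want to inspect; `tilelang.transform.LetInline()` is the standalone pass discussed in the pass-config section below.
+
+```python
+import tvm
+import tilelang
+
+# mod: a tvm.IRModule holding the PrimFunc whose let-bindings you want to inspect
+with tvm.transform.PassContext():
+    inlined_mod = tilelang.transform.LetInline()(mod)
+
+# Bindings that meet the inlining conditions are substituted away in the output.
+print(inlined_mod)
+```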
+
+## How Variables Are Collected
+
+The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:
+
+```cpp
+void VisitBuffer(const Buffer &buf) {
+  // Collect variables that should remain defined
+  VarUseDefAnalyzer usage(Array<Var>{});
+  usage(buf->data);
+  for (const auto &dim : buf->shape) {
+    usage(dim);
+  }
+  for (const auto &dim : buf->strides) {
+    usage(dim);
+  }
+  usage(buf->elem_offset);
+
+  // Track for use in LetStmtNode mutator
+  for (const auto &var : usage.undefined_) {
+    used_in_buffer_def_.insert(var.get());
+  }
+}
+```
+
+## Practical Example: Temporary Variable Issue
+
+Consider this TileLang code:
+
+```python
+for i in T.Parallel(block_N):
+    idx = bx * block_N + i
+    tmp = T.max(A[idx], 1)
+    B[idx] = tmp / 2
+    A[idx] = tmp * 2
+```
+
+In this case:
+- `tmp` is an integer-like temporary variable
+- It satisfies `CanInlineLetStmt` (pure int expression)
+- It's **not** used in any buffer definition
+- Therefore, `tmp` **will be inlined**
+
+This means the IR becomes:
+
+```python
+for i in T.Parallel(block_N):
+    idx = bx * block_N + i
+    B[idx] = T.max(A[idx], 1) / 2
+    A[idx] = T.max(A[idx], 1) * 2
+```
+
+If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern.
+
+## Controlling Let Inlining via Pass Config
+
+TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.
+
+```python
+from tilelang import transform
+from tilelang.engine.phase import LowerAndLegalize
+
+with transform.PassContext(
+        config={transform.PassConfigKey.TL_FORCE_LET_INLINE: True}
+):
+    lowered_mod = LowerAndLegalize(input_mod, target)
+```
+
+If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.
+
+## Summary
+
+The LetStmt inlining mechanism is a **conservative optimization** that:
+1. Aggressively inlines simple, pure integer expressions to simplify the IR
+2. Protects variables used in buffer definitions to avoid breaking buffer access
+3. Helps reduce IR complexity and improve code generation
+4.
Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required + +Understanding when inlining happens is crucial for: +- Debugging compilation issues +- Understanding generated code +- Writing efficient TileLang programs +- Identifying potential optimization opportunities or bugs + +## Related Files + +- `src/transform/simplify.cc`: Main Simplify implementation +- `src/transform/frontend_legalize.cc`: Standalone LetInline pass +- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining +- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass diff --git a/tilelang/original/docs/compiler_internals/tensor_checks.md b/tilelang/original/docs/compiler_internals/tensor_checks.md new file mode 100644 index 0000000000000000000000000000000000000000..b4d2a0b3c03048455b2e2a77e3536292d5f5a202 --- /dev/null +++ b/tilelang/original/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,387 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. +- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages. + +## How To Inspect Host Source +You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging: + +```python +print(matmul_relu_kernel.get_host_source()) +``` + +--- + +## What The Host Checks + +### 1) Argument count and pointer kind +- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message. +- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error. + +### 2) Tensor checks (per tensor, after nullability decision) +- Nullability + - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`. + - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`. +- Rank (`ndim`) + - Runtime `ndim` must equal the compile-time rank. +- Data type (`dtype`) + - Match the triple `(code, bits, lanes)` with tolerance: + - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`. + - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`. + - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match). + - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped. +- Shape + - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency. 
+ - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints. +- Strides + - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality. + - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). +- `byte_offset` + - Must be 0 (non-zero raises an error) to keep addressing simple and aligned. +- Device info + - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend. + - When multiple tensors participate, assert that `device_id` matches across them. +- Data pointer + - Must be non-NULL when the tensor is required to be non-null by the nullability rule. + +### 3) Scalar checks +- `T.int*` family: require integer; error: `Expect arg[i] to be int`. +- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`. + +--- + +## Shapes and Symbolic Equations: Linear Solving +When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example: + +```python +@T.prim_func +def main( + A: T.Tensor((m,), dtype), + B: T.Tensor((m + n,), dtype), + C: T.Tensor((n * k,), dtype), +): + ... +``` + +This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. + +--- + +## Nullability Rules and Examples +Which tensors may be NULL? + +- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL. +- Examples: + +1) Must be non-NULL (used) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + A[0] = 1 +``` +Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`. + +2) Still must be non-NULL (constant-true branch) +```python +some_cond: bool = True +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +3) Nullable (constant-false branch, statically unreachable) +```python +some_cond: bool = False +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +4) Must be non-NULL (runtime condition) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype), some_cond: T.bool): + if some_cond: + A[0] = 1 +``` +Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable. + +--- + +## Device Type Codes (DLPack) +Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. +Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors. 
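+
+Tying the symbolic-shape solving above to a concrete call, here is a minimal sketch. It assumes the three-tensor `main` from the "Shapes and Symbolic Equations" section has been compiled into a CUDA callable `fn` with `dtype = float16`; the host binds `m` from `A`, then solves `n` from `len(B) == m + n`, then `k` from `len(C) == n * k`.
+
+```python
+import torch
+
+m, n, k = 64, 32, 4
+A = torch.empty(m, device="cuda", dtype=torch.float16)        # binds m = 64
+B = torch.empty(m + n, device="cuda", dtype=torch.float16)    # solves n = 32 from m + n
+C = torch.empty(n * k, device="cuda", dtype=torch.float16)    # solves k = 4 from n * k
+fn(A, B, C)  # all cross-tensor constraints are satisfied
+
+# A length that cannot satisfy the bound equations raises the
+# "has an unsatisfied constraint" error described below.
+```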
+ +--- + +## Common Error Examples (What you’ll see) +- Argument count mismatch (num_args) + - Trigger: missing/extra argument + - Error: `: num_args should be N; expected: , got: N` + +- Pointer-typed argument expected + - Trigger: scalar passed where a tensor is expected + - Error: `: Expect arg[i] to be pointer` + +- Rank (ndim) mismatch + - Trigger: runtime rank differs from compile-time rank + - Error: `..ndim is expected to equal R, but got mismatched ndim` + +- Dtype mismatch + - Trigger: dtype not equal to the compiled dtype and not within the tolerance set + - Error: `..dtype is expected to be , but got incompatible dtype` + +- Shape constraint violation + - Trigger: a dimension doesn’t match a constant/symbol binding + - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == ` + +- Strides check failed (e.g., non-contiguous layout) + - Trigger: transposed/sliced tensors that violate expected strides + - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == ` + +- Device type mismatch + - Trigger: calling a CUDA kernel with CPU tensors, etc. + - Error: `..device_type mismatch [expected: ()] ...` + +- Device id mismatch + - Trigger: mixing tensors from different GPUs + - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...` + +- NULL data pointer + - Trigger: tensor required to be non-null has a NULL data pointer + - Error: `. is expected to have non-NULL data pointer, but got NULL` + +- Scalar type mismatch + - Trigger: passing float to `T.int32`, or non-boolean to `T.bool` + - Error: `: Expect arg[i] to be int/boolean` + +--- + +## Troubleshooting Tips +- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields. +- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions. +- Align devices: ensure all participating tensors share the same `device_type` and `device_id`. +- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance. +- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time). + +--- + +## FAQ +- Can I disable the checks? + - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call. +- Is the overhead noticeable? + - The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python. 
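+
+Putting the tips above together, a typical pre-flight fix-up before calling a CUDA float16 kernel looks like the sketch below. It assumes a callable `fn(A, B, C)` and sizes `M`, `N`, `K` following the conventions of the matmul example and the repro section later on this page.
+
+```python
+import torch
+
+# Start from a tensor that violates two checks at once: wrong dtype and a
+# non-contiguous (transposed-view) layout.
+A = torch.randn(K, M, device="cuda", dtype=torch.float32).t()  # shape (M, K), non-contiguous
+B = torch.randn(K, N, device="cuda", dtype=torch.float16)
+C = torch.empty(M, N, device="cuda", dtype=torch.float16)
+
+A = A.contiguous().to(torch.float16)  # fix strides first, then dtype
+fn(A, B, C)                           # now passes the host-side checks
+```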
+ +--- + +## Reference Example (Matmul + ReLU) + +```python +@T.prim_func +def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), +): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + +# For debugging, print the host source +print(matmul_relu_kernel.get_host_source()) +``` + +The host will insert all checks described above for this example. + +--- + +## Quick Error Reference (Short List) +- Argument count + - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`. +- Pointer kind + - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`. +- Rank (ndim) + - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`. +- Dtype + - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `. +- Shape + - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `. +- Strides + - Trigger: layout mismatch; Error: `strides[j] ... == `. +- Device type + - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`. +- Device id + - Trigger: tensors on different GPUs; Error: `device_id ... == ...`. +- Data pointer + - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`. +- Scalar types + - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`. + +--- + +## Host Error Troubleshooting (Minimal Repros) + +Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with: + +```python +# Convention: +# A: float16 [M, K] +# B: float16 [K, N] +# C: float16 [M, N] +# Target: CUDA (device_type=2) +fn = matmul_relu_kernel # your compiled function +M = N = K = 1024 +``` + +Adjust dtype/device if your kernel differs. + +### 0. Tip: print the host source +```python +print(fn.get_host_source()) +``` + +### 1. num_args mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +# Missing C +fn(A, B) +``` +Expected: `: num_args should be 3; expected: , got: 3`. + +Fix: pass all arguments per the signature. + +### 2. Expect pointer (tensor) but got scalar +```python +import torch + +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(1, B, C) +``` +Expected: `: Expect arg[0] to be pointer`. + +Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor). + +### 3. ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. 
dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently. +- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. + diff --git a/tilelang/original/docs/conf.py b/tilelang/original/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..877b5582e1e28ee75704d5c75a8ff900a61c4cd3 --- /dev/null +++ b/tilelang/original/docs/conf.py @@ -0,0 +1,79 @@ +# General information about the project. +project = "TileLang
" +author = "Tile Lang Contributors" +copyright = f"2025-2025, {author}" + +# Version information. +with open("../VERSION") as f: + version = f.read().strip() +release = version + +extensions = [ + "sphinx_tabs.tabs", + "sphinx_toolbox.collapse", + "sphinxcontrib.httpdomain", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx_reredirects", + "sphinx.ext.mathjax", + "myst_parser", + "autoapi.extension", +] + +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] + +autoapi_options = [ + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", +] +autoapi_keep_files = False # Useful for debugging the generated rst files + +autoapi_generate_api_docs = True + +autodoc_typehints = "description" + +autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] + +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} + +myst_enable_extensions = ["colon_fence", "deflist"] + +redirects = {"get_started/try_out": "../index.html#getting-started"} + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md", "**/*libinfo*", "**/*version*"] + +pygments_style = "sphinx" +todo_include_todos = False + +# -- Options for HTML output ---------------------------------------------- + +html_theme = "furo" +templates_path = [] +html_static_path = ["_static"] +html_css_files = ["custom.css"] +footer_copyright = "© 2025-2026 TileLang" +footer_note = " " + +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} + +header_links = [ + ("Home", "https://github.com/tile-ai/tilelang"), + ("Github", "https://github.com/tile-ai/tilelang"), +] + +html_context = { + "footer_copyright": footer_copyright, + "footer_note": footer_note, + "header_links": header_links, + "display_github": True, + "github_user": "tile-ai", + "github_repo": "tilelang", + "github_version": "main/docs/", + "theme_vcs_pageview_mode": "edit", +} diff --git a/tilelang/original/docs/deeplearning_operators/deepseek_mla.md b/tilelang/original/docs/deeplearning_operators/deepseek_mla.md new file mode 100644 index 0000000000000000000000000000000000000000..08175778f0cc80c91aa4bf12023bacd6284fa59c --- /dev/null +++ b/tilelang/original/docs/deeplearning_operators/deepseek_mla.md @@ -0,0 +1,200 @@ +# 🚀 Write High Performance FlashMLA with TileLang on Hopper + + +
+Authors: Yu Cheng, Lei Wang
+
+TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
+
+## Introduction to MLA
+
+DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
+
+## Benchmark Results
+
+We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
+
+```{figure} ../_static/img/mla_hopper/bs64_float16.png
+:width: 50%
+:alt: Overview
+:align: center
+
+Figure 1: Performance under batch size=64
+```
+
+```{figure} ../_static/img/mla_hopper/bs128_float16.png
+:width: 50%
+:alt: Overview
+:align: center
+
+Figure 2: Performance under batch size=128
+```
+
+As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
+Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
+
+## Implementation
+
+First, let's review the core computation logic of traditional FlashAttention:
+
+```python
+# acc_s: [block_M, block_N]
+# scores_max: [block_M]
+# scores_scale: [block_M]
+# acc_o: [block_M, dim]
+
+for i in range(loop_range):
+    acc_s = Q @ K[i]
+    scores_max_prev = scores_max
+    scores_max = max(acc_s, dim=1)
+    scores_scale = exp(scores_max_prev - scores_max)
+    acc_o *= scores_scale
+    acc_s = exp(acc_s - scores_max)
+    acc_o += acc_s @ V[i]
+    ...
+```
+
+Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`; note that `acc_o` is rescaled and accumulated across iterations rather than overwritten. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
+
+Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance.
+
+This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance.
+The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling.
+
+Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right halves of `acc_o`, respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
+
+Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
+
+### Layout Inference
+
+The above process may seem complex, but don't worry - TileLang will handle all these intricacies for you.
+
+Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (i.e., the result matrix is split along the vertical dimension). `T.copy` represents buffer-to-buffer copying operations.
+
+```{figure} ../_static/img/mla_hopper/qk_layout.jpg
+:width: 50%
+:alt: Overview
+:align: center
+
+Figure 3: Buffer shapes in Q @ K
+```
+
+```{figure} ../_static/img/mla_hopper/pv_layout.jpg
+:width: 50%
+:alt: Overview
+:align: center
+
+Figure 4: Buffer shapes in acc_s @ V
+```
+
+The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
+
+For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
+
+It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
+
+### Threadblock Swizzling
+
+Threadblock swizzling is a common technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data.
+The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
+
+In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
+
+```python
+T.use_swizzle(panel_size: int, order: str = "row")
+```
+
+Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
+
+### Shared Memory Swizzling
+
+In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
+
+One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
+
+Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
+
+```python
+T.annotate_layout({
+    S_shared: tilelang.layout.make_swizzled_layout(S_shared),
+})
+```
+
+Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
+
+### Warp-Specialization
+
+The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through `mbarrier` objects.
+
+In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
+
+### Pipeline
+
+Pipelining is a technique that improves memory access efficiency by overlapping memory access and computation. In TileLang, it is enabled through the `T.Pipelined` annotation:
+
+```python
+T.Pipelined(iterations: int, num_stages: int)
+```
+
+Here, `iterations` specifies the loop range to be pipelined, and `num_stages` specifies the number of pipeline stages. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases.
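+
+As a concrete sketch (mirroring the matmul example used in the tensor-checks doc; `K`, `block_K`, the shared buffers, and `C_local` are assumed to be defined as usual), the annotation wraps the reduction loop and TileLang overlaps the copies of later stages with the compute of earlier ones:
+
+```python
+for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
+    T.copy(A[by * block_M, ko * block_K], A_shared)   # load, overlapped with compute
+    T.copy(B[ko * block_K, bx * block_N], B_shared)   # load, overlapped with compute
+    T.gemm(A_shared, B_shared, C_local)               # consumes data staged earlier
+```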
+ + +### Split-KV + +We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. + +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. + + +## 🚀 On AMD MI300X Accelerators + +Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. + +### Architectural Considerations and Optimization Strategies + +Key implementation differences between Hopper and MI300X architectures include: + +1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations. + +2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes: + - Reducing software pipeline stages + - Register-based caching of Q matrices instead of shared memory utilization: + ```python + # Original shared memory allocation + Q_shared = T.alloc_shared([block_H, dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + + # Optimized register allocation + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + ``` + +3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64. + +4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code. + +### Performance Evaluation + +We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate: + +
+
+*AMD FlashMLA Performance Comparison*
+
+Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)
+ +Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, while significantly outperforming both Triton (1.98×) and PyTorch (3.76×) implementations. This performance is achieved through a concise 80-line Python implementation, demonstrating TileLang's efficiency and programmability advantages. + +### Future Optimization Opportunities + +1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction. + +2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to: + - Reduce shared memory pressure + - Improve compute-to-memory access ratios + - Enhance parallelism through dimension-wise task distribution diff --git a/tilelang/original/docs/deeplearning_operators/elementwise.md b/tilelang/original/docs/deeplearning_operators/elementwise.md new file mode 100644 index 0000000000000000000000000000000000000000..f3543c02f5ed4a7d95708de488dba0309ca7bf93 --- /dev/null +++ b/tilelang/original/docs/deeplearning_operators/elementwise.md @@ -0,0 +1,332 @@ +# ElementWise Operators + +
+
+Author: Chenghua Wang
+
+:::{warning}
+:class: myclass1 myclass2
+:name: a-tip-reference
+
+    This document is still **experimental** and may be incomplete.
+    Suggestions and improvements are highly encouraged—please submit a PR!
+:::
+
+Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn:
+
+1. How to implement an elementwise operator using TileLang.
+2. How to compile operators with dynamic shapes.
+3. How TileLang addresses boundary-related issues.
+4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe.
+
+Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the [Overview](../get_started/overview.md).
+
+## Elementwise add in TileLang
+
+```python
+def elementwise_add(N, threads=256, dtype=T.bfloat16):
+
+    @T.prim_func
+    def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
+        with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x):
+            # vector add.
+            for i in T.Parallel(threads):
+                C[b_x * threads + i] = A[b_x * threads + i] + B[b_x * threads + i]
+
+    return main
+```
+
+All logic for TileLang kernels must be implemented within the `T.Kernel(...)` scope. In this example, initializing `T.Kernel(...)` requires specifying both the grid size and the number of threads per block. The returned value `b_x` corresponds to `blockIdx.x` in CUDA. In the provided implementation, `T.Parallel` is used to process the data tile (of size `1 x threads`) assigned to the block for computation.
+
+Those familiar with CUDA programming might wonder where `threadIdx` fits into this. Note that the code inside `T.Kernel` operates at the **block level**, not the **thread level**. In this example, your focus is solely on defining the block-level logic. During compilation, TileLang automatically maps computations to the corresponding threads and applies further optimizations. The optimized code generated by TileLang may closely align with carefully handcrafted computational logic, as demonstrated in Section 2 with a concrete example. While TileLang also supports thread-level programming semantics, this will be covered in subsequent discussions.
+
+The program can be compiled using the following code:
+
+```python
+program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
+kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
+```
+Launching the kernel is straightforward: just call it directly like a function:
+
+```python
+C = kernel(A, B)
+```
+
+The vector add operation can also be extended to two-dimensional cases, where both implementations demonstrate comparable efficiency in practice. Below is an example from the test section that readers can refer to: [example](https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py).
The code for this kernel is provided below: + +```python +import tilelang.language as T +def elementwise_add( + M, + N, + block_M, + block_N, + in_dtype, + out_dtype, + threads, +): + @T.prim_func + def main( + A: T.Tensor((M, N), in_dtype), + B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + + C[y, x] = A[y, x] + B[y, x] + + return main +``` + +### How to compile operators with dynamic shapes? + +In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this: + +```python +program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16) +kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") +``` + +The resulting CUDA code for the kernel will include an additional `int N` parameter after the `bfloat16_t* __restrict__ A`, `bfloat16_t* __restrict__ B`, and `bfloat16_t* __restrict__ C` parameters. + +### How TileLang addresses boundary-related issues. + +TileLang automatically incorporates boundary-checking conditions; however, this comes at a cost. These boundary conditions may prevent TileLang from performing more advanced optimizations. I will introduce an example from the next section in advance. The corresponding code is also provided below, but note that it involves the associated CUDA code. Readers are encouraged to first review the next section before returning to this paragraph for a clearer understanding. + +When compiling the example below, let's set `N` to 2047: + +```python +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x): + # vector add. + for i, j in T.Parallel(threads, num_per_thread): + offsets = (b_x * threads + i) * num_per_thread + C[offsets + j] = A[offsets + j] + B[offsets + j] + + return main +``` + +TileLang will generate the following CUDA code: + +```c++ +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + #pragma unroll + for (int i = 0; i < 8; ++i) { + if (((i * 256) + ((int)threadIdx.x)) < 2047) { + C[((i * 256) + ((int)threadIdx.x))] = (A[((i * 256) + ((int)threadIdx.x))] + B[((i * 256) + ((int)threadIdx.x))]); + } + } +} +``` + +We can observe that TileLang did not apply optimizations such as vectorization or coalesced memory access. In fact, except for the tail group of data, all other threads could have executed more optimized code. + +## Comparison of TileLang, CUDA, and CuTe + +For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity. 
+ +Typically, those new to CUDA programming often write CUDA code in a style similar to this: + +```c++ +// vector add +__global__ void elementwise_add(float* a, float* b, float* c, int N) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < N) { + c[idx] = a[idx] + b[idx]; + } +} +``` + +The code above assigns each thread to compute a single element, which is evidently inefficient since common acceleration techniques like coalesced memory access and vectorization are not utilized. However, TileLang code written with similar logic (e.g., loop-based traversal) can be optimized by the compiler into highly efficient implementations, making it more accessible for beginners. Additionally, the final generated code from the compiler remains observable, providing transparency into the optimization process. + +The CUDA code generated by TileLang for the compiled kernel can be retrieved using the `kernel.get_kernel_source()` method. Below is the CUDA code produced for the vector addition example from Section 1: + +```cu +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + if (((int)threadIdx.x) < 32) { + uint4 __1; + uint4 v_ = *(uint4*)(A + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))); + uint4 v__1 = *(uint4*)(B + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))); + ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x); + ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y); + ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x); + ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y); + ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x); + ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y); + ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x); + ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y); + *(uint4*)(C + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))) = __1; + } +} +``` + +In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access. + +While TileLang incorporates various optimizations for the aforementioned case, its behavior may sometimes appear counterintuitive. For example, when targeting 256 threads for task processing, applying vectorization can result in each thread computing 8 data elements—effectively utilizing only 32 active threads. Interestingly, the kernel launch configuration still retains the original allocation of 256 threads. + +In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design. + +```python +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x): + # vector add. 
+ for i, j in T.Parallel(threads, num_per_thread): + offsets = (b_x * threads + i) * num_per_thread + C[offsets + j] = A[offsets + j] + B[offsets + j] + + return main +``` + +The corresponding CUDA code generated for the above example is presented below: + +```c++ +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + uint4 __1; + uint4 v_ = *(uint4*)(A + (((int)threadIdx.x) * 8)); + uint4 v__1 = *(uint4*)(B + (((int)threadIdx.x) * 8)); + ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x); + ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y); + ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x); + ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y); + ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x); + ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y); + ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x); + ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y); + *(uint4*)(C + (((int)threadIdx.x) * 8)) = __1; +} +``` +Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive. + +But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations. + +```python +def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x): + A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + C_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + + s_start = b_x * threads * NUM_ELE_PER_THREAD + s_end = (b_x + 1) * threads * NUM_ELE_PER_THREAD + + # LDG. 128 + T.copy( + A[s_start:s_end], + A_register, + ) + T.copy( + B[s_start:s_end], + B_register, + ) + + # vector add. + for tid, i in T.Parallel(threads, NUM_ELE_PER_THREAD): + C_register[tid * NUM_ELE_PER_THREAD + i] = ( + A_register[tid * NUM_ELE_PER_THREAD + i] + + B_register[tid * NUM_ELE_PER_THREAD + i]) + + # STG. 128 + T.copy( + C_register, + C[s_start:s_end], + ) + + return main +``` + +In the example above, each thread is responsible for computing 8 elements. The `T.copy(...)` method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation. 
+
+```c++
+// N is set to 8192 * 8192 when compiling
+extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
+  bfloat16_t A_register[8];
+  bfloat16_t B_register[8];
+  *(uint4*)(A_register + 0) = *(uint4*)(A + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
+  *(uint4*)(B_register + 0) = *(uint4*)(B + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
+  uint4 __1;
+  uint4 v_ = *(uint4*)(A_register + 0);
+  uint4 v__1 = *(uint4*)(B_register + 0);
+  ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
+  ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
+  ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
+  ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
+  ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
+  ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
+  ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
+  ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
+  *(uint4*)(A_register + 0) = __1;
+  *(uint4*)(C + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(A_register + 0);
+}
+```
+
+We observe two additional register arrays, `A_register` and `B_register`. During the actual computation, however, their contents are simply reloaded into `v_` and `v__1`, respectively.
+
+To evaluate complexity, one could implement the same elementwise addition operator using CuTe and compare it with the TileLang version. The corresponding CuTe code is provided below:
+
+```c++
+template <int kNumElePerThread>  // number of elements each thread processes (e.g., 8)
+__global__ void elementwise_add(nv_bfloat16* C,
+                                const nv_bfloat16* A,
+                                const nv_bfloat16* B,
+                                int N) {
+  using namespace cute;
+
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  Tensor t_C = make_tensor(make_gmem_ptr(C), make_shape(N));
+  Tensor t_A = make_tensor(make_gmem_ptr(A), make_shape(N));
+  Tensor t_B = make_tensor(make_gmem_ptr(B), make_shape(N));
+
+  Tensor t_C_tile = local_tile(t_C, make_shape(Int<kNumElePerThread>{}), make_coord(idx));
+  Tensor t_A_tile = local_tile(t_A, make_shape(Int<kNumElePerThread>{}), make_coord(idx));
+  Tensor t_B_tile = local_tile(t_B, make_shape(Int<kNumElePerThread>{}), make_coord(idx));
+
+  Tensor reg_buffer_A = make_tensor_like(t_A_tile);
+  Tensor reg_buffer_B = make_tensor_like(t_B_tile);
+  Tensor reg_buffer_C = make_tensor_like(t_C_tile);
+
+  // LDG. 128
+  copy(t_A_tile, reg_buffer_A);
+  copy(t_B_tile, reg_buffer_B);
+
+  auto reg_C_vector = recast<nv_bfloat162>(reg_buffer_C);
+  auto reg_A_vector = recast<nv_bfloat162>(reg_buffer_A);
+  auto reg_B_vector = recast<nv_bfloat162>(reg_buffer_B);
+
+  // Perform vectorized addition
+#pragma unroll
+  for (int vec_idx = 0; vec_idx < size(reg_C_vector); ++vec_idx) {
+    reg_C_vector(vec_idx) = reg_A_vector(vec_idx) + reg_B_vector(vec_idx);
+  }
+
+  auto reg_C_flat = recast<nv_bfloat16>(reg_C_vector);
+
+  // STG. 128
+  copy(reg_C_flat, t_C_tile);
+}
+```
+
+## Conclusion
+
+This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details.
To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels. + +--- + +**Reference:** + +[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999 diff --git a/tilelang/original/docs/deeplearning_operators/gemv.md b/tilelang/original/docs/deeplearning_operators/gemv.md new file mode 100644 index 0000000000000000000000000000000000000000..c75a961b8079b75d4a813658b1cae1899a873353 --- /dev/null +++ b/tilelang/original/docs/deeplearning_operators/gemv.md @@ -0,0 +1,464 @@ +# General Matrix-Vector Multiplication (GEMV) +=========================================== + +
+ Contributor: @botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +Example code can be found at `examples/gemv/example_gemv.py`. +::: + +General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`. + +## Triton Implementation +When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`. + +A simple Triton kernel for GEMV might look like this: +```python +@triton.jit +def _gemv_naive( + x_ptr, A_ptr, y_ptr, + N, K, + BLOCK_SIZE_K: tl.constexpr, +): + n = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_SIZE_K) + mask = offs_k < K + a_ptrs = A_ptr + n * K + offs_k + a_vals = tl.load(a_ptrs, mask=mask, other=0.0) + x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0) + dot = tl.sum(a_vals * x_vals, axis=0) + tl.store(y_ptr + n, dot) +``` + +`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control. + +## Naive Implementation in TileLang +If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example: + +```python +def naive_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: + tn = T.get_thread_binding(0) # tn = threadIdx.x + A_shared = T.alloc_shared((BLOCK_K,), dtype) + B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype) + C_reg = T.alloc_local((1,), accum_dtype) + T.clear(C_reg) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for tk in T.serial(BLOCK_K): + A_shared[tk] = A[bk * BLOCK_K + tk] + B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + for tk in T.serial(BLOCK_K): + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, + tk].astype(accum_dtype) + C[bn * BLOCK_N + tn] = C_reg[0] + + return main +``` + +And your kernel will be compiled into CUDA by `TileLang` (in `~/.tilelang/cache`): + +```C++ +extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + float C_reg[1]; + __shared__ uint64_t _mbarrier[2]; + if (((int)threadIdx.x) == 0) { + tl::mbarrier_init(_mbarrier[0], 128); + tl::mbarrier_init(_mbarrier[1], 128); + } + __syncthreads(); + if (128 <= ((int)threadIdx.x)) { + tl::warpgroup_reg_dealloc<24>(); + for (int bk = 0; bk < 8; ++bk) { + tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1)); + for (int tk = 0; tk < 128; ++tk) { + ((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)]; + ((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)]; + } + tl::fence_proxy_async(); + tl::mbarrier_cp_async_arrive(_mbarrier[0]); + 
tl::mbarrier_arrive(_mbarrier[0]); + } + } else { + tl::warpgroup_reg_alloc<240>(); + C_reg[0] = 0.000000e+00f; + for (int bk_1 = 0; bk_1 < 8; ++bk_1) { + tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1)); + for (int tk_1 = 0; tk_1 < 128; ++tk_1) { + C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)]))); + } + tl::fence_proxy_async(); + tl::mbarrier_arrive(_mbarrier[1]); + } + C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]); + } +} +``` + +In this design, the first 128 threads act as the data producer and the last 128 threads as the consumer within a block (assuming a 1D block). + +At this level, we only gain very little computation power from our GPU with around **~0.17 ms** compared to torch/cuBLAS's **~0.008 ms**, which is around 20x slower. + +## More Concurrency + +To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA. + +Here’s a simplified version: +```python +def naive_splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((1,), dtype) + B_local = T.alloc_local((1,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + A_local[0] = A[bk * BLOCK_K + tk] + B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + +By introducing parallelism along K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS. + +### Customizing Parallelism in K Dimension +If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. 
This way, each thread handles multiple elements per iteration: + +```python +def splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + reduce_threads: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + TILE_K = T.ceildiv(BLOCK_K, reduce_threads) + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.serial(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + + +## Vectorized Reads + +GEMV is less computation intensive than GEMM as the computation intensity and memory throughput will be the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`: + +```python +def splitk_gemv_vectorized( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + +With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance. 
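+
+As a sanity check on these numbers, kernel latency can be measured with the profiler attached to a compiled TileLang kernel (the same pattern used in the matmul tutorial). The sketch below assumes `N = K = 8192` and the `splitk_gemv_vectorized` kernel defined above; the shapes and tile sizes are illustrative assumptions, not the exact benchmark configuration:
+
+```python
+import tilelang
+
+# Build and compile the vectorized split-K GEMV defined above.
+# BLOCK_N=2 and reduce_threads=32 are example values; tune them for your GPU.
+program = splitk_gemv_vectorized(8192, 8192, BLOCK_N=2, reduce_threads=32)
+kernel = tilelang.compile(program, out_idx=[-1], target="cuda")
+
+# Benchmark the compiled kernel and report the average latency in milliseconds.
+profiler = kernel.get_profiler()
+latency = profiler.do_bench()
+print(f"splitk_gemv_vectorized latency: {latency} ms")
+```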
+
+
+## `tvm_thread_allreduce` Instead of `atomicAdd`
+
+[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + `atomicAdd` approach:
+
+```python
+def splitk_gemv_vectorized_tvm(
+    N: int,
+    K: int,
+    BLOCK_N: int,
+    reduce_threads: int,
+    dtype: str = "float16",
+    accum_dtype: str = "float",
+):
+    MAX_TRANSACTION_SIZE_IN_BITS = 128
+    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
+    BLOCK_K = reduce_threads * TILE_K
+
+    @T.prim_func
+    def main(
+            A: T.Buffer((K,), dtype),
+            B: T.Buffer((N, K), dtype),
+            C: T.Buffer((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
+            tn = T.get_thread_binding(0)
+            tk = T.get_thread_binding(1)
+            A_local = T.alloc_local((TILE_K,), dtype)
+            B_local = T.alloc_local((TILE_K,), dtype)
+            C_accum = T.alloc_local((1,), accum_dtype)
+
+            T.clear(C_accum)
+            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
+                for k in T.vectorized(TILE_K):
+                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
+                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
+                for k in T.serial(TILE_K):
+                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
+            C_reduced = T.alloc_local((1,), accum_dtype)
+            with T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+            ):
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        C_accum[0],
+                        True,
+                        C_reduced[0],
+                        tk,
+                        dtype="handle",
+                    ))
+
+            C[bn * BLOCK_N + tn] = C_reduced[0]
+
+    return main
+```
+
+With this optimization, the kernel latency drops from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
+
+## Autotune
+
+`BLOCK_N`, `BLOCK_K`, and `reduce_threads` are hyperparameters of our kernel that can be tuned to improve performance.
We can use the `tilelang.autotune` feature to automatically search for optimal configurations: + +```python +def get_best_config(N, K): + + def get_configs(): + BLOCK_N = [2, 4, 8, 32, 64, 128] + reduce_threads = [4, 8, 32] + _configs = list(itertools.product( + BLOCK_N, + reduce_threads, + )) + configs = [{ + "BLOCK_N": c[0], + "reduce_threads": c[1], + } for c in _configs] + return configs + + @autotune( + configs=get_configs(), + warmup=3, + rep=20, + ) + @jit( + out_idx=[-1], + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + target="auto", + ) + def kernel( + BLOCK_N=None, + reduce_threads=None, + ): + dtype = "float16" + accum_dtype = "float" + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + )) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + return kernel() +``` + +After autotuning, now our kernel gets **~0.0067 ms**, the final generated CUDA kernel might like this: + +```C++ +extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) { + float C_accum[1]; + half_t A_local[8]; + half_t B_local[8]; + __shared__ float red_buf0[64]; + C_accum[0] = 0.000000e+00f; + for (int bk = 0; bk < 4; ++bk) { + *(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8))); + *(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8))); + for (int k = 0; k < 8; ++k) { + C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k]))); + } + } + tl::fence_proxy_async(); + __syncthreads(); + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0]; + __syncthreads(); + if (((int)threadIdx.y) < 16) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 8) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 4) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + 
((int)threadIdx.y)) + 4)]);
+  }
+  __syncthreads();
+  if (((int)threadIdx.y) < 2) {
+    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
+  }
+  __syncthreads();
+  if (((int)threadIdx.y) < 1) {
+    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
+  }
+  __syncthreads();
+  C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
+}
+```
+
+This corresponds closely to our `TileLang` program, with the necessary synchronization and low-level optimizations inserted automatically.
+
+## Conclusion
+
+### Benchmark Table on Hopper GPU
+
+| Kernel Name | Latency |
+|------------|------------|
+| torch/cuBLAS | 0.00784 ms |
+| Triton | 0.00773 ms |
+| naive_gemv | 0.16607 ms |
+| splitk_gemv | 0.02419 ms |
+| splitk_gemv_vectorized | 0.00809 ms |
+| splitk_gemv_vectorized_tvm | 0.00675 ms |
+
+In this tutorial, we implemented a simple GEMV kernel and learned that `TileLang` exposes low-level control to the user, such as thread-level programming and CUDA primitives.
\ No newline at end of file
diff --git a/tilelang/original/docs/deeplearning_operators/matmul.md b/tilelang/original/docs/deeplearning_operators/matmul.md
new file mode 100644
index 0000000000000000000000000000000000000000..fea036ebe48429d8ce40b46a9f5220f5e2d4e828
--- /dev/null
+++ b/tilelang/original/docs/deeplearning_operators/matmul.md
@@ -0,0 +1,259 @@
+# General Matrix-Matrix Multiplication with Tile Library
+
+ Author: Lei Wang +
+ +:::{warning} +:class: myclass1 myclass2 +:name: a-tip-reference + + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction: + +* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. + +* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. + +* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. + +```{figure} ../_static/img/overview.png +:width: 50% +:alt: Overview +:align: center + +Figure 1: High-level overview of the TileLang compilation flow. +``` + +In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance. + +## Why Another GPU DSL? + +TileLang emerged from the need for a DSL that: + +1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed. +2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPU, etc.). +3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning. + +While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints. + +## Matrix Multiplication Example + +```{figure} ../_static/img/MatmulExample.png +:alt: Matmul Example +:align: center + +``` + +### Basic Structure + +Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses: + +* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). +* **`T.alloc_shared(...)`** to allocate GPU shared memory. +* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. +* **`T.Pipelined(...)`** to express software pipelining across the K dimension. +* **`T.Parallel(...)`** to parallelize data copy loops. 
+* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). + +```python +import tilelang +import tilelang.language as T +from tilelang.intrinsics import make_mma_swizzle_layout + +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Optional layout hints (commented out by default) + # T.annotate_layout({ + # A_shared: make_mma_swizzle_layout(A_shared), + # B_shared: make_mma_swizzle_layout(B_shared), + # }) + + # Optional: Enabling swizzle-based rasterization + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A from global to shared memory + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Parallel copy tile of B from global to shared memory + for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] + + # Perform a tile-level GEMM + T.gemm(A_shared, B_shared, C_local) + + # Copy result from local (register fragment) to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + +# 1. Create the TileLang function +func = matmul(1024, 1024, 1024, 128, 128, 32) + +# 2. JIT-compile the kernel for NVIDIA GPU +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda") + +import torch + +# 3. Prepare input tensors in PyTorch +a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) +b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) + +# 4. Invoke the JIT-compiled kernel +c = jit_kernel(a, b) +ref_c = a @ b + +# 5. Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 6. Inspect generated CUDA code (optional) +cuda_source = jit_kernel.get_kernel_source() +print("Generated CUDA kernel:\n", cuda_source) + +# 7. Profile performance +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +``` + +### Key Concepts + +1. **Kernel Context**: + +```python +with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + ... +``` + +- This sets up the block grid dimensions based on N/block_N and M/block_M. +- `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads. + + +```{figure} ../_static/img/Parallel.png +:alt: Parallel +:align: center + +``` + + +2. **Shared & Fragment Memory**: + +```python +A_shared = T.alloc_shared((block_M, block_K), dtype) +B_shared = T.alloc_shared((block_K, block_N), dtype) +C_local = T.alloc_fragment((block_M, block_N), accum_dtype) +``` + +- `T.alloc_shared` allocates shared memory across the entire thread block. +- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread. + +3. **Software Pipelining**: + +```python +for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + ... 
+``` + +- `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic. +- The argument `num_stages=3` indicates the pipeline depth. + +```{figure} ../_static/img/software_pipeline_inference.png +:alt: Software Pipeline Inference +:align: center + +``` + + +4. **Parallel Copy**: + +```python +for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] +``` + +- `T.Parallel` marks the loop for thread-level parallelization. +- The compiler will map these loops to the available threads in the block. + +5. **Tile-Level GEMM**: + +```python +T.gemm(A_shared, B_shared, C_local) +``` + +- A single call that performs a tile-level matrix multiplication using the specified buffers. +- Under the hood, for NVIDIA targets, it can use CUTLASS/Cute or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach. + +6. **Copying Back Results**: + +```python +T.copy(C_local, C[by * block_M, bx * block_N]) +``` + +- After computation, data in the local register fragment is written back to global memory. + +## Comparison with Other DSLs + +TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also: + +- Allows explicit memory layout annotations (e.g. `make_mma_swizzle_layout`). +- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined. +- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2. + +## Performance on Different Platforms + +```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png +:alt: Performance on Different Platforms +:align: center + +``` + +When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks, for an FP16 matrix multiply (e.g., 4090, A100, H100, MI300X), TileLang has shown: + +- ~1.1x speedup over cuBLAS on RTX 4090 +- ~0.97x on A100 (on par with cuBLAS) +- ~1.0x on H100 +- ~1.04x on MI300X +- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware. + +These measurements will vary based on tile sizes, pipeline stages, and the hardware’s capabilities. + +## Conclusion + +This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code: + +1. We allocated shared memory and register fragments. +2. We pipelined the loading and computation along the K dimension. +3. We used parallel copying to efficiently load tiles from global memory. +4. We invoked `T.gemm` to dispatch a tile-level matrix multiply. +5. We verified correctness against PyTorch and examined performance. + +By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs. + +For more advanced usage—including partial lowering, explicitly controlling thread primitives, or using inline assembly—you can explore Level 3. 
Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1. + +## Further Resources + +* [TileLang GitHub](https://github.com/tile-ai/tilelang) +* [BitBLAS](https://github.com/tile-ai/bitblas) +* [Triton](https://github.com/openai/triton) +* [Cutlass](https://github.com/NVIDIA/cutlass) +* [PyCUDA](https://documen.tician.de/pycuda/) diff --git a/tilelang/original/docs/deeplearning_operators/matmul_sparse.md b/tilelang/original/docs/deeplearning_operators/matmul_sparse.md new file mode 100644 index 0000000000000000000000000000000000000000..5910bd6f8c25943ee18bbd65b7ed7fa0b060de5a --- /dev/null +++ b/tilelang/original/docs/deeplearning_operators/matmul_sparse.md @@ -0,0 +1,262 @@ +# Sparse Matrix-Matrix Multiplication with Tile Library + +
+ Author: botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + + This feature is still **experimental** and need further optimization. + + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +It's suggested to go through `docs/deeplearning_operators/matmul.md` first. + +Example code can be found at `examples/gemm_sp`. +::: + +## Structured sparsity in the NVIDIA Ampere architecture + +Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation. + +:::{warning} + This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X. +::: + +```{figure} ../_static/img/sparse_mma_storage_example.png +:align: center + +Figure: Sparse MMA storage example (from PTX doc) +``` + +## Compress a dense tensor + +To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata. + +Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`). + +A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression. + +```python +from tilelang.utils.sparse import compress +A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) +``` + +Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. + +> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor) +The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads). +For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**. + + +## `T.gemm_sp` with CUTLASS's compressor + +:::{warning} + +It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time. + +::: + +A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata. + +Check comments in below kernel code for required modification. 
+ +```python +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ # Annotate reordered cutlass metadata layout + E: + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main +``` + +Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. + +## `T.gemm_sp_v2` with a custom compressor + +To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. + +Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. + +The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. 
+
+Suppose we have the following row vector:
+```python
+t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
+```
+
+The non-zero elements and their corresponding indices are:
+
+```python
+t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
+indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
+```
+
+The corresponding int16 metadata is:
+```python
+# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
+# Note: the 4-bit groups are packed in little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
+# Note: that line is illustrative only and will not run as written, because the Python
+# interpreter does not treat the binary literal as a 16-bit two's-complement value
+metadata_int16 = tensor(-29107)
+```
+
+You can decode an int16 metadata tensor using the following utility:
+```python
+def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
+    assert meta.dtype is torch.int16
+    groups_per_meta = 16 // 4
+    out = []
+    for g in range(groups_per_meta):
+        group_bits = (meta >> (g * 4)) & 0xF
+        idx0 = group_bits & 0x3
+        idx1 = (group_bits >> 2) & 0x3
+        out.append(torch.stack([idx0, idx1], dim=-1))
+    return torch.concat(out, dim=-1).view(meta.shape[0], -1)
+```
+
+The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
+
+For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example, but without the `_calculate_meta_reordering_scatter_offsets` step.
+
+If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
+ +```python + +@tilelang.jit(out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, +}) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: # NOTE: Make sure compressor metadata layout + T.annotate_layout({ # is same with your computation kernel + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, + mma_dtype="float16", + arch="8.0", + block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + non_zero_cnt = T.alloc_local((1, ), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + T.clear(non_zero_cnt) + T.clear(non_zero_elt_log_idx) + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel +``` + +## A note on `gemm_sp` and `gemm_sp_v2` + +Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. + +However, fixing a specific layout introduces several potential issues: + +1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. + +2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. + +3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) + +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. 
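+
+To tie the pieces together, below is a minimal end-to-end sketch of compressing a dense operand and running a sparse GEMM. It assumes the `matmul_sp_sm80` kernel defined earlier in this document (together with its imports) is in scope; the shapes, block sizes, `num_stages`, and `threads` values are illustrative and untuned:
+
+```python
+import torch
+import tilelang
+from tilelang.utils.sparse import compress
+
+M, N, K = 4096, 4096, 4096
+block_M, block_N, block_K = 128, 128, 64
+
+# Compile the sparse GEMM kernel defined above; C (the last argument) is the output.
+func = matmul_sp_sm80(
+    M, N, K, block_M, block_N, block_K,
+    in_dtype="float16", out_dtype="float16", accum_dtype="float",
+    num_stages=2, threads=128, trans_A=False, trans_B=False)
+kernel = tilelang.compile(func, out_idx=[-1], target="cuda")
+
+# Prune A to a 2:4 pattern: keep the two largest-magnitude values in every
+# group of four consecutive elements along K and zero out the other two.
+A = torch.randn(M, K, device="cuda", dtype=torch.float16)
+groups = A.reshape(M, K // 4, 4)
+mask = torch.zeros_like(groups)
+mask.scatter_(-1, groups.abs().topk(2, dim=-1).indices, 1.0)
+A = (groups * mask).reshape(M, K)
+B = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+# Compress the 2:4-sparse A into its non-zero values plus metadata, then launch
+# the kernel on the compressed operands; C should match A @ B up to numerical tolerance.
+A_sparse, E = compress(A, transposed=False, block_k=block_K)
+C = kernel(A_sparse, E, B)
+```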
\ No newline at end of file diff --git a/tilelang/original/docs/get_started/Installation.md b/tilelang/original/docs/get_started/Installation.md new file mode 100644 index 0000000000000000000000000000000000000000..8fa41c023ad82fcf6004b230c9b556f87aaa32a4 --- /dev/null +++ b/tilelang/original/docs/get_started/Installation.md @@ -0,0 +1,260 @@ +# Installation Guide + +## Installing with pip + +**Prerequisites for installation via wheel or PyPI:** + +- **glibc**: 2.28 (Ubuntu 20.04 or later) +- **Python Version**: >= 3.8 +- **CUDA Version**: 12.0 <= CUDA < 13 + +The easiest way to install tilelang is directly from PyPI using pip. To install the latest version, run the following command in your terminal: + +```bash +pip install tilelang +``` + +Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page: + +```bash +pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl +``` + +To install the latest version of tilelang from the GitHub repository, you can run the following command: + +```bash +pip install git+https://github.com/tile-ai/tilelang.git +``` + +After installing tilelang, you can verify the installation by running: + +```bash +python -c "import tilelang; print(tilelang.__version__)" +``` + +## Building from Source + +**Prerequisites for building from source:** + +- **Operating System**: Linux +- **Python Version**: >= 3.8 +- **CUDA Version**: >= 10.0 + +If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment. + +First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands: + +```bash +apt-get update +apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev +``` + +Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process. + +> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies). + +```bash +git clone --recursive https://github.com/tile-ai/tilelang.git +cd tilelang +pip install . -v +``` + +If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation. + +```bash +pip install -e . -v +``` + +> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below. + +(working-from-source-via-pythonpath)= + +### Working from Source via `PYTHONPATH` (Recommended for Developers) + +If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first: + +```bash +mkdir -p build +cd build +cmake .. -DUSE_CUDA=ON +make -j +``` + +We also recommend using `ninja` to speed up compilation: + +```bash +cmake .. 
-DUSE_CUDA=ON -G Ninja +ninja +``` + +Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example: + +```bash +export PYTHONPATH=/path/to/tilelang:$PYTHONPATH +python -c "import tilelang; print(tilelang.__version__)" +``` + +Some useful CMake options you can toggle while configuring: +- `-DUSE_CUDA=ON|OFF` builds against NVIDIA CUDA (default ON when CUDA headers are found). +- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs. +- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`. + +(using-existing-tvm)= + +### Building with Customized TVM Path + +If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: + +```bash +TVM_ROOT= pip install . -v +``` + +> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly. + +(install-using-docker)= + +## Install Using Docker + +For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems. + +**Prerequisites:** +- Docker installed on your system +- NVIDIA Docker runtime or GPU is not necessary for building tilelang, you can build on a host without GPU and use that built image on other machine. + +1. **Clone the Repository**: + +```bash +git clone --recursive https://github.com/tile-ai/tilelang +cd tilelang +``` + +2. **Build Docker Image**: + +Navigate to the docker directory and build the image for your desired CUDA version: + +```bash +cd docker +docker build -f Dockerfile.cu120 -t tilelang-cu120 . +``` + +Available Dockerfiles: +- `Dockerfile.cu120` - For CUDA 12.0 +- Other CUDA versions may be available in the docker directory + +3. **Run Docker Container**: + +Start the container with GPU access and volume mounting: + +```bash +docker run -itd \ + --shm-size 32g \ + --gpus all \ + -v /home/tilelang:/home/tilelang \ + --name tilelang_b200 \ + tilelang-cu120 \ + /bin/zsh +``` + +**Command Parameters Explanation:** +- `--shm-size 32g`: Increases shared memory size for better performance +- `--gpus all`: Enables access to all available GPUs +- `-v /home/tilelang:/home/tilelang`: Mounts host directory to container (adjust path as needed) +- `--name tilelang_b200`: Assigns a name to the container for easy management +- `/bin/zsh`: Uses zsh as the default shell + +4. **Access the Container and Verify Installation**: + +```bash +docker exec -it tilelang_b200 /bin/zsh +# Inside the container: +python -c "import tilelang; print(tilelang.__version__)" +``` + +## Install with Nightly Version + +For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang. + +```bash +pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ +# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/ +``` + +> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet. + +## Install Configs + +### Build-time environment variables +`USE_CUDA`: If to enable CUDA support, default: `ON` on Linux, set to `OFF` to build a CPU version. 
By default, we'll use `/usr/local/cuda` for building tilelang. Set `CUDAToolkit_ROOT` to use a different CUDA toolkit.
+
+`USE_ROCM`: Whether to enable ROCm support, default: `OFF`. If your ROCm SDK is not located in `/opt/rocm`, set `USE_ROCM=` to build ROCm against a custom SDK path.
+
+`USE_METAL`: Whether to enable Metal support, default: `ON` on Darwin.
+
+`TVM_ROOT`: TVM source root to use.
+
+`NO_VERSION_LABEL` and `NO_TOOLCHAIN_VERSION`:
+When building tilelang, we'll try to embed SDK and version information into the package version as shown below,
+where the local version label could look like `.git`. Set `NO_VERSION_LABEL=ON` to disable this behavior.
+```
+$ python -mbuild -w
+...
+Successfully built tilelang-0.1.6.post1+cu116.git0d4a74be-cp38-abi3-linux_x86_64.whl
+```
+
+where `={cuda,rocm,metal}`. Specifically, when `=cuda` and `CUDA_VERSION` is provided via env,
+`=cu`, similar to this part in pytorch.
+Set `NO_TOOLCHAIN_VERSION=ON` to disable this.
+
+### Run-time environment variables
+
+Please refer to the `env.py` file for a full list of supported run-time environment variables.
+
+## Other Tips
+
+### IDE Configs
+
+Building tilelang locally will automatically generate a `compile_commands.json` file in the `build` dir.
+VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration.
+
+### Compile Cache
+
+The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found.
+
+### Repairing Wheels
+
+If you plan to use your wheel in other environments,
+it's recommended to use auditwheel (on Linux) or delocate (on Darwin)
+to repair it.
+
+(faster-rebuild-for-developers)=
+
+### Faster Rebuild for Developers
+
+`pip install` introduces extra [un]packaging and takes ~30 sec to complete,
+even if no sources have changed.
+
+Developers who need to recompile frequently can use:
+
+```bash
+pip install -r requirements-dev.txt
+
+# For first time compilation
+pip install -e . -v --no-build-isolation
+
+# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly.
+mkdir build
+cd build
+cmake .. -G Ninja
+ninja
+
+# Rebuild when you change the cpp code
+cd build; ninja
+```
+
+When running in editable/developer mode,
+you'll see logs like below:
+
+```console
+$ python -c 'import tilelang'
+2025-10-14 11:11:29 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /Users/yyc/repo/tilelang/build
+```
diff --git a/tilelang/original/docs/get_started/Installation_dcu.md b/tilelang/original/docs/get_started/Installation_dcu.md
new file mode 100644
index 0000000000000000000000000000000000000000..d1623de0bf4ccf7e4636617dac9c6e3232d610d4
--- /dev/null
+++ b/tilelang/original/docs/get_started/Installation_dcu.md
@@ -0,0 +1,42 @@
+# Installation for DCU
+## Building from Source
+```bash
+mkdir -p build
+cd build
+cmake .. -DUSE_CUDA=OFF -DUSE_ROCM=ON
+make -j
+```
+
+```bash
+export PYTHONPATH=/path/to/tilelang:$PYTHONPATH
+python -c "import tilelang; print(tilelang.__version__)"
+```
+
+## Other Tips
+### Missing tvm_ffi Module
+If you encounter the error `ModuleNotFoundError: No module named 'tvm_ffi'`, it means the TVM foreign function interface package was not installed. This often happens if the submodules were built manually. Fix it by running:
+```
+# Navigate to the tvm_ffi directory
+cd 3rdparty/tvm/3rdparty/tvm_ffi
+
+# Install the package in editable mode
+pip install .
+ +# Return to the project root +cd ../../../.. +``` +### DTK Path Configuration +If you encounter errors related to DTK path detection (e.g., hipcc not found or failure to retrieve GPU architecture), you may need to manually specify the DTK installation path in the source code. +Locate the file tilelang/contrib/rocm.py and modify the default value of the rocm_path parameter in the get_rocm_arch function (around line 231): + +``` +# File: tilelang/contrib/rocm.py + +# Change from: +def get_rocm_arch(rocm_path="/opt/rocm"): + ... + +# To (for Hygon DCU environments): +def get_rocm_arch(rocm_path="/opt/dtk"): + ... +``` \ No newline at end of file diff --git a/tilelang/original/docs/get_started/overview.md b/tilelang/original/docs/get_started/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..18fa9f1936fcb5f9b5dedb9efd394992acf243f6 --- /dev/null +++ b/tilelang/original/docs/get_started/overview.md @@ -0,0 +1,91 @@ +# The Tile Language: A Brief Introduction + +## Programming Interface + +The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfaces—targeted at **Beginner**, **Developer**, and **Expert** users—that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs. + +```{figure} ../_static/img/overview.png +:width: 50% +:alt: Overview +:align: center + +Figure 1: High-level overview of the TileLang compilation flow. +``` + +## Programming Interfaces + +1. **Beginner Level (Hardware-Unaware)** + - Intended for users who need to write code that is independent of specific hardware details. + - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. + - *Note:* This interface is not yet fully implemented. + +2. **Developer Level (Hardware-Aware with Tile Library)** + - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. + - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. + - Users at this level can leverage these ready-made primitives without diving into low-level threading details. + +3. **Expert Level (Hardware-Aware with Thread Primitives)** + - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). + - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. + - This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures. + +## Compilation Flow + +1. **Tile Program** + A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives. + +2. **Tile Program with Tile Library** + When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations. + +3. 
**Tile Program with Thread Primitives** + Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage. + +4. **IRModule** + After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details. + +5. **Source Code Generation (C/CUDA/HIP/LLVM/…)** + From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD). + +6. **Hardware-Specific Executable/Runtime** + Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures. + +## Tile-based Programming Model + +[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, +illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, +and operator calls to manage data movement and computation with fine-grained control. +In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling +leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization +and reduce latency. +Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` +allows developers to reason about performance-critical optimizations within a user-friendly programming model. + +```{figure} ../_static/img/MatmulExample.png +:align: center +:width: 100% +:alt: GEMM with Multi-Level Tiling on GPUs +:name: fig-overview-gemm + +Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via ``TileLang``. +``` + +### Tile declarations + +At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block. + +### Explicit Hardware Memory Allocation + +A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular: + +- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance. 
+
+- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, each tile allocates local buffers with the same tile-sized shapes as the shared-memory buffers, which might seem counterintuitive: register files are faster than shared memory, but register file space is far more limited. This is because the allocation here refers to the register files for an entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process will be discussed in detail in subsequent sections.
+
+Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in the Layout Inference Pass in the following sections.
+
+```{figure} ../_static/img/LayoutInference.png
+  :align: center
+  :width: 100%
+  :alt: GEMM with Multi-Level Tiling on GPUs
+
+```
diff --git a/tilelang/original/docs/get_started/targets.md b/tilelang/original/docs/get_started/targets.md
new file mode 100644
index 0000000000000000000000000000000000000000..c2b3f2fb5ac7b119e1b084bb8694b99765eab40b
--- /dev/null
+++ b/tilelang/original/docs/get_started/targets.md
@@ -0,0 +1,120 @@
+# Understanding Targets
+
+TileLang is built on top of TVM, which relies on **targets** to describe the device you want to compile for.
+The target determines which code generator is used (CUDA, HIP, Metal, LLVM, …) and allows you to pass
+device-specific options such as GPU architecture flags. This page summarises how to pick and customise a target
+when compiling TileLang programs.
+
+## Common target strings
+
+TileLang ships with a small set of common targets; each accepts the full range of TVM options so you can fine-tune
+the generated code. The most frequent choices are listed below:
+
+| Base name | Description |
+| --------- | ----------- |
+| `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. |
+| `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. |
+| `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. |
+| `metal` | Apple Silicon GPUs (arm64 Macs). |
+| `llvm` | CPU execution; accepts the standard TVM LLVM switches. |
+| `webgpu` | Browser / WebGPU runtimes. |
+| `c` | Emit plain C source for inspection or custom toolchains. |
+
+To add options, append them after the base name, separated by spaces. For example:
+
+```python
+target = "cuda -arch=sm_90"
+kernel = tilelang.compile(func, target=target, execution_backend="cython")
+# or
+@tilelang.jit(target=target)
+def compiled_kernel(*args):
+    return func(*args)
+```
+
+The same convention works for HIP or LLVM (e.g. `hip -mcpu=gfx940`, `llvm -mtriple=x86_64-linux-gnu`).
+
+### Advanced: Specify Exact Hardware
+
+When you already know the precise GPU model, you can encode it in the target string—either via `-arch=sm_XX` or by
+using one of TVM’s pre-defined target tags such as `nvidia/nvidia-h100`. Supplying this detail is optional for
+TileLang in general use, but it becomes valuable when the TVM cost model is enabled (e.g. during autotuning).
The +cost model uses the extra attributes to make better scheduling predictions. If you skip this step (or do not use the +cost model), generic targets like `cuda` or `auto` are perfectly fine. + +All CUDA compute capabilities recognised by TVM’s target registry are listed below. Pick the one that matches your +GPU and append it to the target string or use the corresponding target tag—for example `nvidia/nvidia-a100`. + +| Architecture | GPUs (examples) | +| ------------ | ---------------- | +| `sm_20` | `nvidia/tesla-c2050`, `nvidia/tesla-c2070` | +| `sm_21` | `nvidia/nvs-5400m`, `nvidia/geforce-gt-520` | +| `sm_30` | `nvidia/quadro-k5000`, `nvidia/geforce-gtx-780m` | +| `sm_35` | `nvidia/tesla-k40`, `nvidia/quadro-k6000` | +| `sm_37` | `nvidia/tesla-k80` | +| `sm_50` | `nvidia/quadro-k2200`, `nvidia/geforce-gtx-950m` | +| `sm_52` | `nvidia/tesla-m40`, `nvidia/geforce-gtx-980` | +| `sm_53` | `nvidia/jetson-tx1`, `nvidia/jetson-nano` | +| `sm_60` | `nvidia/tesla-p100`, `nvidia/quadro-gp100` | +| `sm_61` | `nvidia/tesla-p4`, `nvidia/quadro-p6000`, `nvidia/geforce-gtx-1080` | +| `sm_62` | `nvidia/jetson-tx2` | +| `sm_70` | `nvidia/nvidia-v100`, `nvidia/quadro-gv100` | +| `sm_72` | `nvidia/jetson-agx-xavier` | +| `sm_75` | `nvidia/nvidia-t4`, `nvidia/quadro-rtx-8000`, `nvidia/geforce-rtx-2080` | +| `sm_80` | `nvidia/nvidia-a100`, `nvidia/nvidia-a30` | +| `sm_86` | `nvidia/nvidia-a40`, `nvidia/nvidia-a10`, `nvidia/geforce-rtx-3090` | +| `sm_87` | `nvidia/jetson-agx-orin-32gb`, `nvidia/jetson-agx-orin-64gb` | +| `sm_89` | `nvidia/geforce-rtx-4090` | +| `sm_90a` | `nvidia/nvidia-h100` (DPX profile) | +| `sm_100a` | `nvidia/nvidia-b100` | + +Refer to NVIDIA’s [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page or the TVM source +(`3rdparty/tvm/src/target/tag.cc`) for the latest mapping between devices and compute capabilities. + +## Creating targets programmatically + +If you prefer working with TVM’s `Target` objects, TileLang exposes the helper +`tilelang.utils.target.determine_target` (returns a canonical target string by default, or the `Target` +object when `return_object=True`): + +```python +from tilelang.utils.target import determine_target + +tvm_target = determine_target("cuda -arch=sm_80", return_object=True) +kernel = tilelang.compile(func, target=tvm_target) +``` + +You can also build targets directly through TVM: + +```python +from tvm.target import Target + +target = Target("cuda", host="llvm") +target = target.with_host(Target("llvm -mcpu=skylake")) +``` + +TileLang accepts either `str` or `Target` inputs; internally they are normalised and cached using the canonical +string representation. **In user code we strongly recommend passing target strings rather than +`tvm.target.Target` instances—strings keep cache keys compact and deterministic across runs, whereas constructing +fresh `Target` objects may lead to slightly higher hashing overhead or inconsistent identity semantics.** + +## Discovering supported targets in code + +Looking for a quick reminder of the built-in base names and their descriptions? Use: + +```python +from tilelang.utils.target import describe_supported_targets + +for name, doc in describe_supported_targets().items(): + print(f"{name:>6}: {doc}") +``` + +This helper mirrors the table above and is safe to call at runtime (for example when validating CLI arguments). + +## Troubleshooting tips + +- If you see `Target cuda -arch=sm_80 is not supported`, double-check the spellings and that the option is valid for + TVM. 
Any invalid switch will surface as a target-construction error. +- Runtime errors such as “no kernel image is available” usually mean the `-arch` flag does not match the GPU you are + running on. Try dropping the flag or switching to the correct compute capability. +- When targeting multiple environments, use `auto` for convenience and override with an explicit string only when + you need architecture-specific tuning. diff --git a/tilelang/original/docs/index.md b/tilelang/original/docs/index.md new file mode 100644 index 0000000000000000000000000000000000000000..55804259a46a857b3589919e0536e565950ffa2d --- /dev/null +++ b/tilelang/original/docs/index.md @@ -0,0 +1,74 @@ +# 👋 Welcome to Tile Language + +[GitHub](https://github.com/tile-ai/tilelang) + +Tile Language (tile-lang) is a concise domain-specific language designed to streamline +the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). +By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, +tile-lang allows developers to focus on productivity without sacrificing the +low-level optimizations necessary for state-of-the-art performance. + +:::{toctree} +:maxdepth: 2 +:caption: GET STARTED + +get_started/Installation +get_started/overview +get_started/targets +::: + + +:::{toctree} +:maxdepth: 1 +:caption: TUTORIALS + +tutorials/debug_tools_for_tilelang +tutorials/auto_tuning +tutorials/logging +::: + +:::{toctree} +:maxdepth: 1 +:caption: PROGRAMMING GUIDES + +programming_guides/overview +programming_guides/language_basics +programming_guides/instructions +programming_guides/control_flow +programming_guides/autotuning +programming_guides/type_system +::: + +:::{toctree} +:maxdepth: 1 +:caption: DEEP LEARNING OPERATORS + +deeplearning_operators/elementwise +deeplearning_operators/gemv +deeplearning_operators/matmul +deeplearning_operators/matmul_sparse +deeplearning_operators/deepseek_mla +::: + +:::{toctree} +:maxdepth: 1 +:caption: COMPILER INTERNALS + +compiler_internals/letstmt_inline +compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks +::: + +:::{toctree} +:maxdepth: 1 +:caption: API Reference + +autoapi/tilelang/index +::: + +:::{toctree} +:maxdepth: 1 +:caption: Privacy + +privacy +::: diff --git a/tilelang/original/docs/make.bat b/tilelang/original/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..2034948c27be13f33e1b04492b15a94f0a4af284 --- /dev/null +++ b/tilelang/original/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/tilelang/original/docs/privacy.md b/tilelang/original/docs/privacy.md new file mode 100644 index 0000000000000000000000000000000000000000..3fb712bc28772d60b1d5eaf2597308716958b418 --- /dev/null +++ b/tilelang/original/docs/privacy.md @@ -0,0 +1,3 @@ +# Privacy + +All data stays in users' device and is not collected by the app. diff --git a/tilelang/original/docs/programming_guides/autotuning.md b/tilelang/original/docs/programming_guides/autotuning.md new file mode 100644 index 0000000000000000000000000000000000000000..66d46889fe88e4b1de874d552e2f1c922c534660 --- /dev/null +++ b/tilelang/original/docs/programming_guides/autotuning.md @@ -0,0 +1,308 @@ +# Autotuning + +TileLang includes a built‑in autotuner that searches configuration spaces +for the best performing kernel, compiles candidates in parallel, validates +correctness, benchmarks them, and caches the best result for reuse. + +This guide covers two workflows: +- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit` +- Programmatic: `AutoTuner.from_kernel(...).set_*().run()` + +It also explains input tensor supply, validation, caching, and environment +variables that affect parallelism and cache behavior. + +## 1) Decorator‑based Autotune + +Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as +function arguments with defaults. The autotuner overrides these parameters with +values from your config space. + +```python +import tilelang +import tilelang.language as T + +def matmul_configs(M, N, K): + # Example space — tailor to your target + tiles = [64, 128] + stages = [2, 3] + threads = [128, 256] + return [ + dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH) + for BM in tiles + for BN in tiles + for BK in [32, 64] + for S in stages + for TH in threads + ] + +@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60) +@tilelang.jit(out_idx=[-1]) +def matmul(M: int, N: int, K: int, + block_M: int = 128, block_N: int = 128, block_K: int = 32, + threads: int = 128, num_stages: int = 3, + dtype: str = 'float16', accum_dtype: str = 'float32'): + + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), dtype) + B_s = T.alloc_shared((block_K, block_N), dtype) + C_f = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_f) + + T.copy(C_f, C[by * block_M, bx * block_N]) + + return kernel + +# Usage +# Provide inputs via context (recommended for reproducibility across configs) +import torch +M = N = K = 1024 +A = torch.randn(M, K, device='cuda', dtype=torch.float16) +B = torch.randn(K, N, device='cuda', dtype=torch.float16) +C = torch.empty(M, N, device='cuda', dtype=torch.float16) + +from tilelang.autotuner import set_autotune_inputs +with set_autotune_inputs(A, B, C): + tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel + tuned_kernel(A, B, C) # run best 
kernel +``` + +Notes +- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each + dict’s keys must match the tunable function arguments (e.g., `block_M`). +- The decorator returns a callable that runs autotune once per argument tuple + and caches the resulting best kernel in‑process. +- For explicit input control during tuning, wrap the call with + `set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used. + +## 2) Programmatic Autotune + +Use the `AutoTuner` class to manage configs and arguments more explicitly. + +```python +from tilelang.autotuner import AutoTuner + +kernel_factory = matmul # the function above (already @tilelang.jit) +tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K)) + +tuner.set_profile_args( + warmup=25, rep=100, timeout=60, + supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog + ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2), +) + +tuner.set_compile_args( + target='auto', # or 'cuda'/'hip'/'metal' + execution_backend='auto', # resolves per-target + out_idx=[-1], # which outputs to return if multiple + pass_configs={ # optional TVM passes/flags + # tilelang.PassConfigKey.EXAMPLE_KEY: value, + }, +) + +artifact = tuner.run() # compiles + runs + validates all configs +best_kernel = artifact.kernel # JITKernel +best_latency = artifact.latency +best_config = artifact.config + +# Reuse best kernel +best_kernel(A, B, C) +``` + +### Example Gallery (in repo) +- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs +- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune` +- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler` +- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup +- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes +- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent + +Click any path to open the code and compare patterns. + +## Input Tensor Supply + +The tuner needs inputs to compile and benchmark kernels. Provide them in one of +three ways (priority order): + +1) Context manager (fixed inputs across configs) +```python +with set_autotune_inputs(A, B, C): + tuned = matmul(M, N, K) +``` + +2) Custom supplier program +```python +def supply_prog(signature): + # signature holds KernelParam objects describing shapes/dtypes + # Return a list of torch tensors matching the kernel’s arguments + return [A, B, C] + +tuner.set_profile_args(supply_prog=supply_prog) +``` + +3) Built‑in generators via `supply_type` +- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges) +- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One` + +Important +- Built‑in generators require static shapes; if your PrimFunc uses symbolic + dimensions (T.dyn), supply concrete inputs via (1) or (2). +- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support. + +## Correctness Checking and Tolerances + +Use one of the following validation methods: +- `ref_prog`: Provide a reference program that receives the same inputs and + checks results. You can return a boolean or raise on mismatch. +- `manual_check_prog`: A callable that inspects outputs and raises on mismatch. +- `skip_check=True`: Skip correctness checks (faster, use with caution). 
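+
+For example, here is a minimal sketch of the two checking hooks for the matmul tuning example above (the tensor arguments mirror that example; the exact signature expected by `manual_check_prog` is an assumption, so treat this as illustrative only):
+
+```python
+import torch
+
+# Reference program: receives the same inputs as the tuned kernel and
+# returns a boolean (raising an exception on mismatch also works).
+def ref_prog(A, B, C):
+    expected = (A @ B).to(C.dtype)
+    return torch.allclose(C, expected, rtol=1e-2, atol=1e-2)
+
+# Manual check: inspect the outputs yourself and raise on mismatch.
+def manual_check_prog(A, B, C):  # assumed signature, for illustration only
+    expected = (A @ B).to(C.dtype)
+    torch.testing.assert_close(C, expected, rtol=1e-2, atol=1e-2)
+
+tuner.set_profile_args(ref_prog=ref_prog)
+```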
+ +Control numeric drift via: +- `rtol` and `atol` (defaults 1e‑2) +- `max_mismatched_ratio` (default 1%) + +## Configuration Spaces and Best Practices + +What to tune +- Tile sizes: `block_M`, `block_N`, `block_K` +- Software pipelining: `num_stages` +- Threads per block: `threads` (or (x, y) tuple) +- Optional: dtype variants, epilogues, small scheduling knobs + +Tips +- Start from a working baseline. Tune a small, meaningful space first. +- Respect hardware limits (shared memory bytes, registers per thread/block, + max threads per block). Eliminate impossible configs up‑front. +- Keep block sizes multiples of vector widths and warp sizes when relevant. +- Use `set_autotune_inputs` to ensure each config is measured on identical data. +- Record your best configs and bake them as defaults when stable. + +## Parallel Compilation/Benchmarking and Timeouts + +The tuner compiles configurations in parallel using a thread pool and benchmarks +them with a per‑config timeout. On CUDA, each worker sets the current device to +avoid context issues. + +Notes +- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect. +- Logs are written to `autotuner.log` in the working directory. + +## Caching + +The autotuner caches best artifacts both in‑memory (per process) and on disk under +`$TILELANG_CACHE_DIR/autotuner`. The cache key includes: +- TileLang version, function source, closure free‑vars +- Config list, compile args, profile args + +Disk cache contents (per key) +- Best config and latency: `best_config.json`, `latency.json` +- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend) +- Function and params: `function.pkl`, `params.pkl` + +Control via env vars (tilelang.env) +- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`) +- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`) +- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1` +- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1` + +CPU worker control +- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9) +- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto) +- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited) + +Backend notes +- NVRTC backend persists `.cubin` and a Python launcher. +- Torch/DLPack backend may not save artifacts to disk; in this case, only + in‑memory caching applies and a warning is logged. + +## Alternative: Manual Sweeps with par_compile + +If you prefer manual control, use `JITImpl.par_compile` to compile a batch of +configs and drive your own benchmarking: + +```python +@tilelang.jit +def factory(M, N, K, block_M=128, block_N=128, block_K=32): + @T.prim_func + def k(A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16')): + ... + return k + +impl = factory # JITImpl +cfgs = [ + dict(block_M=64, block_N=128, block_K=32), + dict(block_M=128, block_N=128, block_K=64), +] +kernels = impl.par_compile(cfgs, num_workers=4) +# Now benchmark kernels[i](A, B, C) yourself +``` + +## Recording and Reusing Best Configs + +The programmatic path returns an `AutotuneResult` that can be saved and later +reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs. 
+ +```python +artifact = tuner.run() # AutotuneResult + +# Save to disk +from pathlib import Path +save_dir = Path('out/best/matmul_1024') +artifact.save_to_disk(save_dir, verbose=True) + +# Reload later +from tilelang.autotuner.param import AutotuneResult, CompileArgs +restored = AutotuneResult.load_from_disk(save_dir, CompileArgs()) +best = restored.kernel +best(A, B, C) +``` + +Notes +- DLPack/Torch execution backend may not persist compiled binaries; in that + case, re‑compilation is needed on load or use a different backend. +- The directory contains human‑readable JSONs (best config/latency) and sources. + +## Advanced: Config Space Callables + +Derive config spaces from problem sizes to keep searches targeted and legal: + +```python +def matmul_configs(M, N, K): + large = min(M, N, K) >= 1024 + tiles = [128] if large else [64, 128] + for BM in tiles: + for BN in tiles: + for BK in [32, 64]: + for S in [2, 3]: + for TH in [128, 256]: + yield dict(block_M=BM, block_N=BN, block_K=BK, + num_stages=S, threads=TH) +``` + +## Device and Backend Selection + +Tune compile‑time options explicitly: +- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target) +- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'` +- `pass_configs={...}` to toggle TileLang/TVM passes for experiments + +On CUDA with multiple GPUs, the tuner sets the current device per worker thread +to avoid context mixups. + +## Troubleshooting +- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable. +- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that + your reference check isn’t the bottleneck. +- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom + `supply_prog`. +- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend. diff --git a/tilelang/original/docs/programming_guides/control_flow.md b/tilelang/original/docs/programming_guides/control_flow.md new file mode 100644 index 0000000000000000000000000000000000000000..158c51166e501a6628618e028ed0bbd904f7a47d --- /dev/null +++ b/tilelang/original/docs/programming_guides/control_flow.md @@ -0,0 +1,145 @@ +# Control Flow + +This guide covers the control‑flow primitives in TileLang and how they lower to +efficient GPU code. You will use these to structure loops, handle boundaries, +and express pipelined compute. + +## Overview +- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`) +- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined` +- While loops: `while` with a TIR condition +- Flow control: Python `break` / `continue` +- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass + +The examples assume `import tilelang.language as T`. + +## Conditionals + +Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels. +Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are +treated as compile‑time constants and will be folded. + +```python +for i in T.serial(N): + if i < N: # TIR condition + C[i] = A[i] + B[i] + else: + pass + +# Ternary +x = (A[i] if i < N else 0) +``` + +Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use +`T.any_of` / `T.all_of` for clarity: + +```python +if T.all_of(i < M, j < N): + C[i, j] = A[i, j] + B[i, j] +``` + +Boundary handling note +- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access + may be out‑of‑bounds, and elides them when proven safe. 
You can often omit + explicit `if` checks for simple edge handling, but keep them when you need + custom logic or clarity. + +## Loops + +### Serial + +`T.serial` creates a plain for‑loop. Common forms: + +```python +for i in T.serial(N): + ... # 0..N-1 + +for i in T.serial(0, N, 2): + ... # 0, 2, 4, ... +``` + +### Unroll + +`T.unroll` requests loop unrolling for small trip counts. + +```python +for k in T.unroll(K_TILE): + acc += a[k] * b[k] +``` + +Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are +available for expert tuning. + +### Parallel (elementwise) + +`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise +operations. The body receives all indices in one `for` header: + +```python +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] +``` + +Optional: `coalesced_width=` can hint memory coalescing for the innermost loop. + +### Pipelined (software pipelining) + +`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g., +Global→Shared copies with compute). This is the backbone of GEMM/attention +pipelines. + +```python +for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile + T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile + T.gemm(A_s, B_s, C_f) # stage: compute +``` + +### Persistent (advanced) + +`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent +thread‑block style looping. It is an advanced construct that TileLang lowers in +later passes and is typically used by specialized templates. + +## While Loops + +`while` is supported when the condition is a TIR expression. Avoid infinite +loops; TileLang will error if it detects a constant‑true condition. + +```python +i = 0 +while i < N: + ... + if done: + break + i += 1 +``` + +## Break and Continue + +Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/ +`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for +readability; the compiler will ignore the dead path. + +## Putting It Together: Residual Tile Handling + +Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess, +the explicit guard can be omitted when you don’t need a custom edge path. + +```python +for i, j in T.Parallel(M, N): + gi = by * BM + i + gj = bx * BN + j + if T.all_of(gi < M, gj < N): # optional in many cases + C[gi, gj] = A[gi, gj] + B[gi, gj] +``` + +## Debugging Conditions + +Use `T.print` to inspect values under predicates. For buffers, TileLang prints +from a single thread to avoid duplicate outputs. + +```python +if i == 0: + T.print(C, msg='C tile:') +``` diff --git a/tilelang/original/docs/programming_guides/instructions.md b/tilelang/original/docs/programming_guides/instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..84bd9217990003044a97e6c59007486cae64566f --- /dev/null +++ b/tilelang/original/docs/programming_guides/instructions.md @@ -0,0 +1,182 @@ +# Instructions + +This page summarizes the core TileLang “instructions” available at the DSL +level, how they map to hardware concepts, and how to use them correctly. 
+ +## Quick Categories +- Data movement: `T.copy`, `T.c2d_im2col`, staging Global ↔ Shared ↔ Fragment +- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`), + reductions (`T.reduce_sum`, `T.cumsum`, warp reducers) +- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view` +- Diagnostics: `T.print`, `T.device_assert` +- Advanced: atomics, memory barriers, warp‑group ops + +## Data Movement + +Use `T.copy(src, dst, coalesced_width=None, disable_tma=False, eviction_policy=None)` +to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or +`BufferRegion`; extents are inferred or broadcast when possible. + +```python +# Global → Shared tiles (extents inferred from dst) +T.copy(A[by * BM, ko * BK], A_s) +T.copy(B[ko * BK, bx * BN], B_s) + +# Fragment/Register → Global (store result) +T.copy(C_f, C[by * BM, bx * BN]) +``` + +Semantics +- Extents are deduced from arguments; missing sides broadcast to the other’s rank. +- Access patterns are legalized and coalesced during lowering. Explicit + vectorization is not required in HL mode. +- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an + access may be out‑of‑bounds and drops them when proven safe. + +Other helpers +- `T.c2d_im2col(img, col, ...)`: convenience for conv‑style transforms. + +## Compute Primitives + +GEMM and sparse GEMM +- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared + inputs and a fragment accumulator; lowered to target‑specific tensor cores. +- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README). + +Reductions and scans +- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp + reducers (`T.warp_reduce_sum`, etc.). +- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or + `T.fill`. + +Elementwise math +- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, + `T.sigmoid`, etc. Compose freely inside loops. + +Reshape/view (no copy) +- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create + new views that share storage, with shape/dtype checks enforced. + +## Synchronization (HL usage) + +In HL pipelines, you usually don’t need to write explicit barriers. Passes such +as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate +producer/consumer ordering and thread synchronization behind the scenes. + +If you need debugging or explicit checks: +- `T.device_assert(cond, msg='')` emits device‑side asserts on CUDA targets. +- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread. + +## Putting It Together: GEMM Tile + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # Global → Shared + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # compute into fragment + + T.copy(C_f, C[by * BM, bx * BN]) # store back +``` + +## Instruction Reference (Concise) + +Below is a concise list of TileLang instructions grouped by category. For full +signatures, behaviors, constraints, and examples, refer to API Reference +(`autoapi/tilelang/index`). 
+ +Data movement +- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment. +- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv. + +Memory allocation and descriptors +- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer. +- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment. +- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem). +- `T.alloc_barrier(arrive_count)`: Shared barrier buffer. +- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+). +- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf. +- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator. + - `T.alloc_wgmma_desc(dtype='uint64')` + - `T.alloc_tcgen05_smem_desc(dtype='uint64')` + - `T.alloc_tcgen05_instr_desc(dtype='uint32')` +- `T.empty(shape, dtype='float32')`: Declare function output tensors. + +Compute primitives +- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator. +- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM. +- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`. +- Scans: `T.cumsum`, finalize: `T.finalize_reducer`. +- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`. +- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...). +- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`. +- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding). +- Helpers: `T.clear(buf)`, `T.fill(buf, value)`. +- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`. + +Diagnostics +- `T.print(obj, msg='')`: Print scalar/buffer from one thread. +- `T.device_assert(cond, msg='')`: Device-side assert (CUDA). + +Logical helpers +- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates. + +Annotation helpers +- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint. +- `T.annotate_layout({...})`: Attach explicit layouts to buffers. +- `T.annotate_safe_value(var, ...)`: Safety/const hints. +- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint. + +Atomics +- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`. +- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`. +- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`. + +Custom intrinsics +- `T.dp4a(A, B, C)`: 4‑element dot‑product accumulate. +- `T.clamp(x, lo, hi)`: Clamp to [lo, hi]. +- `T.loop_break()`: Break from current loop via intrinsic. + +Barriers, TMA, warp‑group +- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`. +- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`. +- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`. +- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`, + `T.tma_store_arrive(...)`, `T.tma_store_wait(...)`. +- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`. +- Warp‑group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`, + `T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`. + +Lane/warp index +- `T.get_lane_idx(warp_size=None)`: Lane id in warp. +- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync). +- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync). +- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id. 
+ +Register control +- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`. +- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`. +- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`. + + + +## Notes on Dtypes + +Dtypes accept three equivalent forms: +- String: `'float32'` +- TileLang dtype: `T.float32` +- Framework dtype: `torch.float32` +All are normalized internally. See Type System for details. diff --git a/tilelang/original/docs/programming_guides/language_basics.md b/tilelang/original/docs/programming_guides/language_basics.md new file mode 100644 index 0000000000000000000000000000000000000000..1152680c970460f3c91f871d2b0e82ec73034918 --- /dev/null +++ b/tilelang/original/docs/programming_guides/language_basics.md @@ -0,0 +1,234 @@ +# Language Basics + +This page introduces the core TileLang (tile‑lang) DSL that you’ll use to write +high‑performance kernels. It focuses on how to define a kernel, express +iteration, move data across memory scopes, and run it with JIT. + +The examples use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## 1. Defining a Kernel with `@T.prim_func` + +TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func` +decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or +`T.Buffer`. + +Note on dtypes +- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`), + or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these. + See Type System for details. + +```python +@T.prim_func +def add_kernel( + A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32 + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), +): + ... # kernel body +``` + +- Shapes may be concrete integers or symbolic. For symbolic, you can pass + Python ints through the outer `@jit` wrapper (shown below), or annotate with + `T.dyn` when you want a named symbolic dimension. + +```python +# Named symbolic dimension (optional) +K = T.dyn['K'] +@T.prim_func +def uses_dyn(A: T.Tensor((K,), 'float32')): + ... +``` + +### Dynamic symbolic dimensions: two ways + +TileLang supports two complementary ways to introduce symbolic (dynamic) dims: + +- Type-level annotations via `T.dyn[...]` (recommended for function signatures) + - Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above). + - Inside the kernel body, prefer reading from the buffer’s shape, e.g. `M = A.shape[0]`. + +- Term-level variables via `T.dynamic(name, dtype)` + - Creates a TIR `tir.Var` you can use directly in expressions/loops. + - Handy when you need to reference the dimension symbol in the body. + +```python +# 1) Annotation-only symbol; read the bound size via shape +K = T.dyn['K'] # dtype defaults to int32 +@T.prim_func +def foo(A: T.Tensor((K,), 'float32')): + N = A.shape[0] + for i in T.serial(N): + ... + +# 2) Explicit Var symbol usable in the body +K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32 +@T.prim_func +def bar(A: T.Tensor((K,), 'float32')): + for i in T.serial(K): + ... +``` + +Notes +- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`. +- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call. +- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes. + +## 2. 
Launching Work with `T.Kernel` + +`with T.Kernel(...)` declares a launch context and creates block/thread +bindings. For GPU backends, specify a grid and threads per block. + +```python +with T.Kernel(grid_x, grid_y, threads=128) as (bx, by): + ... # bx/by are blockIdx.x/y +``` + +You rarely need raw thread indices; most kernels use structured loops +(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`. + +## 3. Loops and Control Flow + +Core loop constructs map to familiar hardware patterns: + +- `T.serial(start, stop[, step])`: plain for‑loop +- `T.unroll(start, stop[, step])`: unrolled loop +- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwise‑friendly) +- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer + +```python +for i in T.serial(N): + ... + +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] + +for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + # overlap copy/compute across stages + ... +``` + +Conditionals use standard Python `if`/`else`. Guard edges with predicates when +tile sizes do not divide problem sizes evenly. + +## 4. Memory Scopes and Allocation + +TileLang exposes key software‑managed scopes: + +- Global: device memory (default for `T.Tensor` arguments) +- Shared: on‑chip, block‑visible (`T.alloc_shared(shape, dtype)`) +- Fragment and scalars: per‑thread fragments and scalar vars but in Shared View + (`T.alloc_fragment`, `T.alloc_var`) + +```python +A_shared = T.alloc_shared((BM, BK), 'float16') +B_shared = T.alloc_shared((BK, BN), 'float16') +C_local = T.alloc_fragment((BM, BN), 'float32') +T.clear(C_local) # zero accumulators +``` + +## 5. Moving Data: `T.copy` + +Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers, +buffer regions, or buffer loads; extents are inferred or can be broadcast. + +```python +# Global -> Shared (tile copy), extents inferred from dst +T.copy(A[by * BM, ko * BK], A_shared) +T.copy(B[ko * BK, bx * BN], B_shared) + +# Fragment -> Global (store back) +T.copy(C_local, C[by * BM, bx * BN]) +``` + +`T.copy` performs coalescing and scope‑specific lowering during compilation. + +## 6. A Minimal End‑to‑End Example (Vector Add) + +```python +import tilelang +import tilelang.language as T +from tilelang import jit + +@jit # infers target from tensors at first call +def add(N: int, block: int = 256, dtype: str = 'float32'): + + @T.prim_func + def add_kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block), threads=block) as bx: + for i in T.Parallel(block): + gi = bx * block + i + # Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB + C[gi] = A[gi] + B[gi] + + return add_kernel + +# Host side (PyTorch shown; NumPy/DLPack also supported) +import torch +N = 1 << 20 +A = torch.randn(N, device='cuda', dtype=torch.float32) +B = torch.randn(N, device='cuda', dtype=torch.float32) +C = torch.empty(N, device='cuda', dtype=torch.float32) + +kernel = add(N) +kernel(A, B, C) # runs on GPU +torch.testing.assert_close(C, A + B) +``` + +Notes +- The `@jit` wrapper returns a callable kernel after the first compilation. +- You can pass compile‑time tunables (tile sizes, dtypes) through the outer + Python function and bake them into the generated TIR. + +## 7. Tiled GEMM Skeleton + +Below is a minimal pattern for a tiled GEMM using shared memory staging and a +fragment accumulator. It mirrors the quickstart style found in the repository. 
+ +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # lowered to tensor‑core/ISA specific kernels + + T.copy(C_f, C[by * BM, bx * BN]) +``` + +## 8. Debugging and Printing + +Use `T.print` inside a kernel for quick introspection. TileLang emits printing +from a single thread for shared/fragment scopes to avoid floods. + +```python +T.print(C_f, msg='accumulator:') +T.print(A_s, msg='A tile:') +T.print(C[0], msg='C[0] = ') +``` + +## 9. Where to Go Next + +- Control flow details: see Programming Guides → Control Flow +- Memory topics: see Programming Guides → (removed cache/layout); basics are covered inline +- Autotuning tile sizes and mappings: Programming Guides → Autotuning +- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators diff --git a/tilelang/original/docs/programming_guides/overview.md b/tilelang/original/docs/programming_guides/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..64b6d20390bd7350001a9882abda10906cea2874 --- /dev/null +++ b/tilelang/original/docs/programming_guides/overview.md @@ -0,0 +1,27 @@ +# Programming Guides Overview + +This section provides a practical guide to writing high‑performance kernels with Tile Language (tile‑lang). +It mirrors the structure of a similar guide in another project and adapts it to tile‑lang concepts and APIs. + +- Audience: Developers implementing custom GPU/CPU kernels with tile‑lang +- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions +- Scope: Language basics, control flow, instructions, autotuning, and type system + +## What You’ll Learn +- How to structure kernels with TileLang’s core DSL constructs +- How to move data across global/shared/fragment and pipeline compute +- How to apply autotuning to tile sizes and schedules +- How to specify and work with dtypes in kernels + +## Suggested Reading Order +1. Language Basics +2. Control Flow +3. Instructions +4. Autotuning +5. Type System + +## Related Docs +- Tutorials: see existing guides in `tutorials/` +- Operators: examples in `deeplearning_operators/` + +> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve. diff --git a/tilelang/original/docs/programming_guides/type_system.md b/tilelang/original/docs/programming_guides/type_system.md new file mode 100644 index 0000000000000000000000000000000000000000..32b9274d7c48bcd3b7ebfcdc4b35a56862b86f10 --- /dev/null +++ b/tilelang/original/docs/programming_guides/type_system.md @@ -0,0 +1,42 @@ +# Type System + +This page lists the data types supported by TileLang and how to specify them in +kernels. For full details and the authoritative list, see the API Reference +(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`. + +How to specify dtypes +- Use any of the following forms; TileLang normalizes them internally: + - String: `'float32'`, `'int8'`, `'bfloat16'`, ... + - TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ... + - Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ... 
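+
+As a quick illustration, the sketch below mirrors the vector-add kernel from Language Basics and mixes the three spellings in one signature purely for demonstration (the shapes and kernel body are arbitrary):
+
+```python
+import torch
+import tilelang.language as T
+
+N, block = 1024, 256
+
+@T.prim_func
+def add_kernel(
+    A: T.Tensor((N,), 'float16'),       # string form
+    B: T.Tensor((N,), T.float16),       # TileLang dtype object
+    C: T.Tensor((N,), torch.float16),   # framework (torch) dtype
+):
+    with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
+        for i in T.Parallel(block):
+            gi = bx * block + i
+            C[gi] = A[gi] + B[gi]
+```
+
+All three annotations above are normalized to the same `float16` element type.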
+ +Common scalar types +- Boolean: `bool` +- Signed integers: `int8`, `int16`, `int32`, `int64` +- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64` +- Floating‑point: `float16` (half), `bfloat16`, `float32`, `float64` + +Float8 and low‑precision families +- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`, + `float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu` +- Float6: `float6_e2m3fn`, `float6_e3m2fn` +- Float4: `float4_e2m1fn` + +Vectorized element types (SIMD packs) +- For many base types, vector‑packed variants are available by lane count: + `x2`, `x4`, `x8`, `x16`, `x32`, `x64`. +- Examples: + - Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ... + - Unsigned: `uint8x2`, `uint8x4`, ... + - Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ... + - Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable, + e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`. + +Notes +- Availability of certain low‑precision formats (float8/6/4) depends on target + architecture and backend support. +- Choose accumulation dtypes explicitly for mixed‑precision compute (e.g., + GEMM with `float16` inputs and `float32` accumulators). +- The complete, up‑to‑date list is exposed in + `tilelang.language.v2.dtypes` and rendered in the API Reference. + diff --git a/tilelang/original/docs/requirements.txt b/tilelang/original/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..63b64db21c52b066942f05082c71ae78c10fd2b3 --- /dev/null +++ b/tilelang/original/docs/requirements.txt @@ -0,0 +1,13 @@ +fastapi +pydantic +sphinx +sphinx-reredirects +sphinx-tabs +sphinx-toolbox +sphinxcontrib-napoleon +sphinxcontrib_httpdomain +furo +uvicorn +myst-parser +sphinx-autoapi == 3.6.0 +astroid < 4 diff --git a/tilelang/original/docs/spelling_wordlist.txt b/tilelang/original/docs/spelling_wordlist.txt new file mode 100644 index 0000000000000000000000000000000000000000..e859d0e7b109baafea24b390c4c0331393950123 --- /dev/null +++ b/tilelang/original/docs/spelling_wordlist.txt @@ -0,0 +1,8 @@ +cancelled +hsa +ist +LOD +nd +NotIn +offen +te diff --git a/tilelang/original/docs/tutorials/auto_tuning.md b/tilelang/original/docs/tutorials/auto_tuning.md new file mode 100644 index 0000000000000000000000000000000000000000..3f3cad832232898017a439deff4fb84cc499a412 --- /dev/null +++ b/tilelang/original/docs/tutorials/auto_tuning.md @@ -0,0 +1,148 @@ +Auto-Tuning Techniques for Performance Optimization +=================================================== +
+Author: yyttt6 +
+
+## Overview
+
+Auto-tuning a Tile Language program involves three main steps:
+
+1. Implement the target program in Tile Language with reserved optimization parameters
+2. Provide candidate configurations through manual search or [auto-generation using Carver](#using-carver-to-auto-generate-candidate-configurations)
+3. Compile and benchmark the candidate configurations in parallel to identify the best-performing one
+
+## Matrix Multiplication Example
+
+The following example demonstrates auto-tuning matrix multiplication. The code has been simplified for readability; see `examples/gemm/example_gemm.py` for the complete implementation.
+
+### Step 1: Implement with Reserved Parameters
+Users can implement matrix multiplication in Tile Language while reserving parameters for optimization:
+```python
+# Reserved parameters for optimization
+def kernel(
+    block_M=None,
+    block_N=None,
+    block_K=None,
+    num_stages=None,
+    thread_num=None,
+    enable_rasteration=None,
+):
+    dtype = "float16"
+    accum_dtype = "float"
+
+    # Matrix multiplication implementation
+    @T.prim_func
+    def main(
+            A: T.Buffer((M, K), dtype),
+            B: T.Buffer((N, K), dtype),
+            C: T.Buffer((M, N), dtype),
+    ):
+        # ...existing code...
+
+    return main
+```
+### Step 2: Generate Candidate Configurations
+Manually define configurations or use combinatorial generation:
+```python
+configs = [
+    {
+        "block_M": 128,
+        "block_N": 128,
+        "block_K": 128,
+        "num_stages": 3,
+        "thread_num": 128,
+        "enable_rasteration": True
+    },
+    {
+        "block_M": 32,
+        "block_N": 32,
+        "block_K": 32,
+        "num_stages": 0,
+        "thread_num": 32,
+        "enable_rasteration": False
+    },
+    # ...additional configurations...
+]
+```
+Configurations can also be generated by combinatorial traversal of the parameter space:
+```python
+import itertools
+
+block_M = [64, 128, 256]
+block_N = [64, 128, 256]
+block_K = [32, 64]
+num_stages = [0, 1, 2, 3]
+thread_num = [128, 256]
+enable_rasterization = [True, False]
+_configs = list(
+    itertools.product(
+        block_M,
+        block_N,
+        block_K,
+        num_stages,
+        thread_num,
+        enable_rasterization,
+    ))
+
+configs = [
+    {
+        "block_M": c[0],
+        "block_N": c[1],
+        "block_K": c[2],
+        "num_stages": c[3],
+        "thread_num": c[4],
+        "enable_rasteration": c[5]
+    } for c in _configs
+]
+```
+### Step 3: Compile and Benchmark
+Configure JIT compilation and benchmarking settings:
+```python
+autotuner = AutoTuner.from_kernel(
+    kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
+        out_idx=[-1],
+        supply_type=tl.TensorSupplyType.Integer,
+        ref_prog=ref_program,
+        skip_check=False,
+        target="auto",
+    )
+result = autotuner.run(warmup=3, rep=20)
+out_c = result.kernel(a, b)
+```
+The result object contains the optimized kernel implementation, which can be used directly.
+
+## Using Carver to Auto-Generate Candidate Configurations
+
+Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.
+
+For common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`):
+
+```python
+# Configure Matmul template
+arch = CUDA("cuda")
+carve_template = MatmulTemplate(
+    M=M,
+    N=N,
+    K=K,
+    in_dtype="float16",
+    out_dtype="float16",
+    accum_dtype="float",
+).with_arch(arch)
+
+# Generate top-k optimization hints (topk=10 recommended)
+roller_hints = carve_template.recommend_hints(topk=10)
+
+# Configure candidate parameters
+for hint in roller_hints:
+
+    # ...existing code...
+
+    config["block_M"] = block_m
+    config["block_N"] = block_n
+    config["block_K"] = hint.rstep[0]
+    config["num_stages"] = hint.pipeline_stage
+    config["thread_num"] = block_rows * block_cols * 32
+    config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
+
+```
\ No newline at end of file
diff --git a/tilelang/original/docs/tutorials/debug_tools_for_tilelang.md b/tilelang/original/docs/tutorials/debug_tools_for_tilelang.md
new file mode 100644
index 0000000000000000000000000000000000000000..f8dfaab826e0c2896a922a915cce44b2f9ab8153
--- /dev/null
+++ b/tilelang/original/docs/tutorials/debug_tools_for_tilelang.md
@@ -0,0 +1,204 @@
+# Debugging Tile Language Programs
+
+Author: Lei Wang +
+ +## Overview + +A Tile Language program (hereafter referred to as a *program*) is transformed into a hardware-executable file through several stages: + +1. The user writes a Tile Language program. +2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing an intermediate representation (e.g., LLVM or C for CPU, CUDA for NVIDIA GPUs, etc.). +3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file. + + +```{figure} ../_static/img/overview.png +:width: 300 +:alt: Overview of the compilation process +:align: center + +``` + +During this process, users may encounter roughly three categories of issues: + +* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). +* **Correctness issues**: The resulting executable runs, but produces incorrect results. +* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. + +This tutorial focuses on the first two issues—how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) for further hardware-level analysis, which we will address in future materials. + +Below, we take matrix multiplication (GEMM) as an example to demonstrate how to write and debug a Tile Language program. + +## Matrix Multiplication Example + +In **Tile Language**, you can use the **Tile Library** to implement matrix multiplication. Here's a complete example: + +```python +import tilelang +import tilelang.language as T + +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + # ...existing code... + +# 1. Define the kernel (matmul) with the desired dimensions +func = matmul(1024, 1024, 1024, 128, 128, 32) + +# 2. Compile the kernel into a torch function +# ...existing code... +``` + +## Debugging Generation Issues + +TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into `T.Parallel` (see the pass `LowerTileOP`), which is then expanded again, eventually resulting in lower-level statements that can be translated to CUDA C code. + + +```{figure} ../_static/img/ir_transform_diagram.png +:width: 400 +:alt: IR transformation diagram +:align: center + +``` + +When the code fails to generate (for instance, a compilation error occurs), you do **not** necessarily need to jump directly into C++ passes to debug. Instead, you can first inspect the intermediate representations (IR) in Python by printing them. + +For example, consider a case where a simple `T.copy` in 1D causes the lowering process to fail. The snippet below illustrates a simplified version of the problem (based on community Issue #35): + +```python +@T.prim_func +def main(Q: T.Tensor(shape_q, dtype)): + # ...existing code... +``` + +The TileLang lower process might yield an error such as: + +```text +File "/root/TileLang/src/target/codegen_cuda.cc", line 1257 +ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed. +``` + +This indicates that somewhere during code generation, an unsupported vectorization pattern was introduced (a ramp of 8 lanes). Before diving into the underlying C++ code, it is helpful to print the IR right before code generation. 
For instance: + +```python +device_mod = tir.transform.Filter(is_device_call)(mod) +# ...existing code... +``` + +## Debugging Correctness Issues + +Sometimes, the kernel compiles and runs but produces incorrect results. In such cases, there are two main strategies to help debug: + +1. **Use post-processing callbacks to inspect or modify the generated CUDA code.** +2. **Use the built-in `T.print` debugging primitive to inspect values at runtime.** + +### Post-Processing Callbacks for Generated Source + +After code generation (in the codegen pass), TileLang calls a callback function (if registered) to allow post-processing of the generated source code. In `src/target/rt_mod_cuda.cc`: + +```cpp +std::string code = cg.Finish(); +if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) { + code = (*f)(code, target).operator std::string(); +} +``` + +Hence, by registering a Python function named `tilelang_callback_cuda_postproc`, you can intercept the final CUDA code string. For example: + +```python +import tilelang +import tilelang.language as T +from tilelang import tvm +from tilelang.engine.callback import register_cuda_postproc_callback + +@register_cuda_postproc_callback +def tilelang_callback_cuda_postproc(code, _): + print(code) # print the final CUDA code + code = "// modified by tilelang_callback_cuda_postproc\n" + code + return code + +kernel = tilelang.compile(matmul, target="cuda") +kernel_source = kernel.get_kernel_source() +print(kernel_source) +''' +// modified by tilelang_callback_cuda_postproc +#include "cuda_runtime.h" +... +''' +``` + +### Runtime Debug Prints with `T.print` + +TileLang provides a built-in debugging primitive called `T.print` for printing within kernels. Be mindful of concurrency and thread synchronization when using it in GPU code. Below are some examples showing how to print buffers, variables, and other data inside TileLang programs. + +1. **Printing an Entire Buffer** + +```python +def debug_print_buffer(M=16, N=16): + # ...existing code... +``` + +2. **Conditional Printing** + +```python +def debug_print_buffer_conditional(M=16, N=16): + # ...existing code... +``` + +3. **Printing Thread Indices or Scalar Values** + +```python +def debug_print_value_conditional(M=16, N=16): + # ...existing code... +``` + +4. **Printing Fragment (Register File) Contents** + +```python +def debug_print_register_files(M=16, N=16): + # ...existing code... +``` + +5. **Adding a Message Prefix** + +```python +def debug_print_msg(M=16, N=16): + # ...existing code... +``` + +The output messages will include something like: + +```text +msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0 +``` + +### Visual Layout Inference For TileLang + The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations. + +When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates: +1. **Textual output**: A human-readable description of the layout mapping +2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping + +The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation. 
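+
+As a rough sketch, enabling it for a single compilation might look like the following (this assumes the options are supplied through the `pass_configs` argument of `tilelang.compile`; the exact key strings used here are an assumption, so consult the API Reference; the format strings are described below):
+
+```python
+# Hypothetical invocation: enable layout visualization for one compilation.
+kernel = tilelang.compile(
+    matmul,
+    target="cuda",
+    pass_configs={
+        "TL_LAYOUT_VISUALIZATION_ENABLE": True,   # assumed key name
+        "TL_LAYOUT_VISUALIZATION_FORMATS": "txt,svg",
+    },
+)
+```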
+
+When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control the output formats:
+- "txt": text output only (the default)
+- "all": generate all formats (TXT, PDF, PNG, SVG)
+- "png": generate PNG format only
+- "pdf": generate PDF format only
+- "svg": generate SVG format only
+- "txt,svg": generate multiple formats (comma-separated) in addition to text output
+
+The "txt" output will look something like:
+```text
+C_local inferenced layout:
+  Shape: [32, 32] -> [8]
+  Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
+  Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
+```
+
+
+## Conclusion
+
+By carefully examining intermediate representations (IR) before final code generation, and by leveraging runtime printing through `T.print`, one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
+
+For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents.
diff --git a/tilelang/original/docs/tutorials/logging.md b/tilelang/original/docs/tutorials/logging.md
new file mode 100644
index 0000000000000000000000000000000000000000..5caf432801cd7d0dfa1b5dfb2eef5c9e824e937c
--- /dev/null
+++ b/tilelang/original/docs/tutorials/logging.md
@@ -0,0 +1,118 @@
+Logging in Tilelang/TVM
+===================================================
+
+Author: SiriusNEO +
+
+## TVM Logging Overview
+
+Tilelang currently uses the logging system from TVM. The implementation can be found in:
+
+- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): macro definitions
+- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): logging logic implementation
+
+The design style is inspired by [Google's glog](https://google.github.io/glog/stable/).
+
+## Logging Categories
+
+There are three primary macro types:
+
+```c++
+LOG(INFO) << "aaa";
+DLOG(INFO) << "aaa";
+VLOG(1) << "aaa";
+```
+
+- **LOG**: Standard logging that stays in the code to display necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`.
+- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the `TVM_LOG_DEBUG` macro and is **eliminated in Release builds through dead-code elimination**.
+  - This build-time elimination is the key difference between LOG(DEBUG) and DLOG. We recommend DLOG over LOG(DEBUG), as the latter overlaps in functionality but gets compiled into the release runtime.
+- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels: VLOG(n) accepts n = 1, 2, 3, 4, 5, or 6, enabling complex tracing requirements, whereas LOG and DLOG use predefined levels such as INFO and DEBUG.
+  - In practical Tilelang development, VLOG is used less frequently.
+  - TVM's VLOG is implemented on top of DLOG and thus inherits DLOG's characteristics.
+
+Additional useful macros include the various **CHECK** variants:
+
+```c++
+CHECK(cond) << "error msg";
+DCHECK(cond) << "error msg";
+ICHECK(cond) << "error msg";
+```
+
+The implementation routes errors to LogFatal:
+
+```c++
+#define CHECK(x)                                                \
+  if (!(x))                                                     \
+  ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
+      << "Check failed: (" #x << ") is false: "
+```
+- **DCHECK**: Debug-mode CHECK, only compiled in Debug builds
+- **ICHECK**: Internal CHECK that remains in Release builds. When an ICHECK fails, the entire system should report an error.
+
+## Logging Verbose Levels
+
+TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog):
+
+```c++
+#define TVM_LOG_LEVEL_DEBUG 0
+#define TVM_LOG_LEVEL_INFO 1
+#define TVM_LOG_LEVEL_WARNING 2
+#define TVM_LOG_LEVEL_ERROR 3
+#define TVM_LOG_LEVEL_FATAL 4
+```
+
+## Using Logging in TileLang Development
+
+### Guidelines
+
+For temporary debugging output in your code, there are no restrictions (you can even use `std::cout`). Just remember to remove it before submitting a PR.
+
+For meaningful logging that should remain in the Tilelang codebase:
+
+- Critical correctness checks: use ICHECK with error messages detailed enough to facilitate debugging when issues arise.
+- Complex pass debugging: for passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG.
+- General INFO/WARNING messages: use standard LOG.
+
+### Enabling Log Output in Tilelang
+
+To set the log level at runtime, set the environment variable `TVM_LOG_DEBUG`. For example:
+
+```bash
+TVM_LOG_DEBUG=1 python3 code.py
+```
+
+This enables all DEBUG/INFO (level <= 1) logs for all files.
+
+#### Detailed Rules for TVM_LOG_DEBUG Specification
+
+The parsing logic is in `logging.cc`.
Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163).
+
+Launch Python with `TVM_LOG_DEBUG=<spec>`, where `<spec>` is a comma-separated list of level assignments in the form `<fname>=<level>`. Important notes:
+
+- The special filename `DEFAULT` sets the LOG level for all files.
+- `<level>` can be set to -1 to disable LOG for that file.
+- `<fname>` is the C++ source filename (e.g., a .cc file, not a .h file) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths.
+
+### Enabling Debug Mode
+
+To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode:
+
+```bash
+cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON
+```
+
+Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` macro, compiling in all DLOG statements:
+
+```cmake
+target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
+```
+
+Then you also need to set the runtime environment variable. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code at INFO level (1): `TVM_LOG_DEBUG=1`.
+
+:::{note}
+ **Important**: There are two TVM_LOG_DEBUG variables. (1) Compile-time macro: determines whether debug content (like DLOG) is compiled into the .so file. It is referenced in C++ source via #ifdef TVM_LOG_DEBUG and is automatically enabled when using Debug build mode in CMake. (2) Runtime environment variable: controls the logging level at runtime. TVM provides a specification for this variable, allowing control over per-file logging levels.
+
+ These two should ideally have different names, but TVM uses the same name for both, which can cause confusion.
+:::
+
+
diff --git a/tilelang/original/examples/amd/example_amd_flash_attn_bwd.py b/tilelang/original/examples/amd/example_amd_flash_attn_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..788aec367c45a56bda3bd001c3c8b5e6c6ebfe45
--- /dev/null
+++ b/tilelang/original/examples/amd/example_amd_flash_attn_bwd.py
@@ -0,0 +1,590 @@
+import torch
+import torch.nn.functional as F
+import tilelang
+import tilelang.language as T
+from tilelang.tileop.base import GemmWarpPolicy
+import itertools
+import argparse
+from functools import partial
+import numpy as np
+import time
+
+
+def ref_program(Q, K, V, is_causal, groups=1):
+    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
+    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
+    dim = Q.size(-1)
+    K_ref = K.repeat_interleave(groups, dim=2)
+    V_ref = V.repeat_interleave(groups, dim=2)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    if is_causal:
+        seq_len = Q.size(1)
+        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
+        mask = mask.unsqueeze(0).unsqueeze(0)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    attention_weights = F.softmax(scores, dim=-1)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
+    lse = torch.logsumexp(scores, dim=-1).float()
+    return output, lse
+
+
+def get_fwd_configs():
+    block_M = [32, 64, 128, 256]
+    block_N = [32, 64, 128, 256]
+    threads = [128, 256, 512]
+    num_split_q = [64, 128, 256]
+    num_stages = [0, 1]
+    enable_rasterization = [True]
+    k_pack = [2]
+    panel_size = [7, 8, 9, 10]
+    qk_coalesced_width = [8]
+    v_coalesced_width = [4]
+
+    valid_configs = []
+
+    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
+        block_M, block_N, num_split_q, threads, num_stages, enable_rasterization,
k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) + return valid_configs + + +@tilelang.autotune(configs=get_fwd_configs(), cache_input_tensors=True) +@tilelang.jit(out_idx=[3, 4]) +def fast_flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_split_q: int, + threads: int, + num_stages: int, + enable_rasterization: bool, + k_pack: int, + panel_size: int, + qk_coalesced_width: int, + v_coalesced_width: int, +): + scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + vec_size = qk_coalesced_width + v_vec_size = v_coalesced_width + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): + T.use_swizzle(panel_size, enable=enable_rasterization) + + bz = byz_combined // heads + by = byz_combined % heads + + num_q_blocks = T.ceildiv(seq_len, block_M) + + bx_loop_var = T.alloc_var(T.int32) + bx_loop_var = b_split + + with T.While(bx_loop_var < num_q_blocks): + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + m_i = T.alloc_fragment([block_M], accum_dtype) + l_i = T.alloc_fragment([block_M], accum_dtype) + + T.fill(acc_o, 0) + T.fill(m_i, -T.infinity(accum_dtype)) + T.fill(l_i, 0) + + current_bx = bx_loop_var + q_block_offset = current_bx * block_M + + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + m_prev = T.alloc_fragment([block_M], accum_dtype) + scale_factor = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) + + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + + row_sum = T.alloc_fragment([block_M], accum_dtype) + + for k in T.Pipelined(loop_end_k, num_stages=num_stages): + kv_idx = k * block_N + + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + k_pack=k_pack, + policy=GemmWarpPolicy.FullRow, + ) + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = acc_s[i, j] * scale + + T.copy(m_i, m_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) + + for i in T.Parallel(block_M): + if m_prev[i] == -T.infinity(accum_dtype): + scale_factor[i] = 0.0 + else: + scale_factor[i] = T.exp(m_prev[i] - m_i[i]) + + l_i[i] *= scale_factor[i] + + for i, j in T.Parallel(block_M, dim): + 
acc_o[i, j] *= scale_factor[i] + + for i, j in T.Parallel(block_M, block_N): + if acc_s[i, j] == -T.infinity(acc_s.dtype): + acc_s[i, j] = 0.0 + else: + acc_s[i, j] = T.exp(acc_s[i, j] - m_i[i]) + + T.reduce_sum(acc_s, row_sum, dim=1) + for i in T.Parallel(block_M): + l_i[i] += row_sum[i] + + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) + + l_inv = T.alloc_fragment([block_M], accum_dtype) + for i in T.Parallel(block_M): + safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0) + l_inv[i] = 1.0 / safe_l + + for i, j in T.Parallel(block_M, dim): + Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i] + + for i in T.Parallel(block_M): + if q_block_offset + i < seq_len: + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + LSE[bz, by, q_block_offset + i] = lse_val + + bx_loop_var = current_bx + num_split_q + + return main + + +def get_bwd_configs(): + block_M = [16, 32, 64, 128, 256] + block_N = [16, 32, 64, 128, 256] + threads = [64, 128, 256, 512, 1024] + num_stages = [0, 1, 2] + enable_rasterization = [True] + panel_size = [7, 8, 9, 10] + + configs = [] + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) + + return configs + + +@tilelang.jit(out_idx=[2]) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) +@tilelang.jit +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + T.use_swizzle(panel_size, enable=enable_rasterization) + + K_shared = T.alloc_shared([block_M, dim], dtype) + V_shared = 
T.alloc_shared([block_M, dim], dtype) + q_shared = T.alloc_shared([block_N, dim], dtype) + do_shared = T.alloc_shared([block_N, dim], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta_shared = T.alloc_shared([block_N], accum_dtype) + ds_shared = T.alloc_shared([block_M, block_N], dtype) + + p_cast = T.alloc_fragment([block_M, block_N], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + P_acc = T.alloc_fragment([block_M, block_N], accum_dtype) + dP = T.alloc_fragment([block_M, block_N], accum_dtype) + + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) + T.clear(qkT) + + T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + + for i, j in T.Parallel(block_M, block_N): + P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) + + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) + T.clear(dP) + + T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(P_acc, p_cast) + T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) + + for i, j in T.Parallel(block_M, block_N): + p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(p_cast, ds_shared) + T.clear(dq) + T.gemm(ds_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + + for i, j in T.Parallel(block_M, dim): + T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j]) + T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j]) + + return flash_bwd_kernel + + +@tilelang.jit(out_idx=[1]) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.copy( + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +def debug_tensor_comparison(tensor1, tensor2, name, rtol=1e-3, atol=1e-3): + print(f"\n=== {name} Comparison ===") + print(f"Shape: {tensor1.shape} vs {tensor2.shape}") + print(f"Data type: {tensor1.dtype} vs {tensor2.dtype}") + print(f"Device: {tensor1.device} vs {tensor2.device}") + + diff = torch.abs(tensor1 - tensor2) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + std_diff = diff.std().item() + + print(f"Max difference: {max_diff:.6f}") + print(f"Mean difference: {mean_diff:.6f}") + print(f"Difference std: {std_diff:.6f}") + + if 
max_diff > atol: + max_idx = torch.argmax(diff) + max_idx = np.unravel_index(max_idx.cpu().numpy(), tensor1.shape) + print(f"Max difference position: {max_idx}") + print(f"Value1: {tensor1[max_idx].item():.6f}, Value2: {tensor2[max_idx].item():.6f}") + + nan_count1 = torch.isnan(tensor1).sum().item() + nan_count2 = torch.isnan(tensor2).sum().item() + inf_count1 = torch.isinf(tensor1).sum().item() + inf_count2 = torch.isinf(tensor2).sum().item() + + print(f"NaN count: {nan_count1} vs {nan_count2}") + print(f"Inf count: {inf_count1} vs {inf_count2}") + + relative_diff = diff / (torch.abs(tensor2) + 1e-8) + max_relative_diff = relative_diff.max().item() + mean_relative_diff = relative_diff.mean().item() + + print(f"Max relative difference: {max_relative_diff:.6f}") + print(f"Mean relative difference: {mean_relative_diff:.6f}") + + close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol) + print(f"Within tolerance (rtol={rtol}, atol={atol}): {close}") + + return close, max_diff, mean_diff + + +def benchmark_function(func, *args, warmup=10, repeat=100): + for _ in range(warmup): + func(*args) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + times = [] + for _ in range(repeat): + start = time.time() + func(*args) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end = time.time() + times.append((end - start) * 1000) + + return np.median(times) + + +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): + device = "cuda" + dtype = torch.float16 + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") + + flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 5 * flops_per_gemm + + print(f"Total FLOPs: {total_flops / 1e12:.2f} TFlops") + + q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) + k = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype) + v = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype) + dO = torch.randn_like(q) + + print("Starting autotuning for Fast FlashAttention-V2 Forward Pass...") + fwd_kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups) + if fwd_kernel is None or fwd_kernel.config is None: + print("Forward pass auto-tuning failed.") + return + print(f"Autotuning finished. Best Forward Configuration: {fwd_kernel.config}") + + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + + profiler = fwd_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + print("Verifying correctness...") + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("Forward pass is correct.") + + o_tl, lse_tl = fwd_kernel(q, k, v) + + bwd_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim) + delta_tl = bwd_prep(o_tl, dO) + + print("\nStarting FlashAttention-V2 backward pass autotuning...") + bwd_kernel = flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups) + if bwd_kernel is None or bwd_kernel.config is None: + print("Backward pass autotuning failed.") + return + print(f"Autotuning completed. 
Best backward pass configuration: {bwd_kernel.config}") + + dQ_accum = torch.zeros_like(q, dtype=torch.float32) + dK_tl = torch.zeros_like(k, dtype=torch.float32) + dV_tl = torch.zeros_like(v, dtype=torch.float32) + + bwd_kernel(q, k, v, dO, lse_tl, delta_tl, dQ_accum, dK_tl, dV_tl) + + post_kernel = flashattn_bwd_postprocess(batch, heads, seq_len, dim) + dQ_tl = post_kernel(dQ_accum) + + q_ref = q.clone().detach().requires_grad_() + k_ref = k.clone().detach().requires_grad_() + v_ref = v.clone().detach().requires_grad_() + + o_ref, _ = ref_program(q_ref, k_ref, v_ref, is_causal, groups) + o_ref.backward(dO) + + print("Verifying backward pass correctness...") + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + if dq_close: + print("dQ is correct.") + else: + print("dQ mismatch detected.") + + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + if dk_close: + print("dK is correct.") + else: + print("dK mismatch detected.") + + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + if dv_close: + print("dV is correct.") + else: + print("dV mismatch detected.") + + print("\n=== Performance Benchmarking ===") + + def run_reference_fwd_bwd(): + q_ref_bench = q.clone().detach().requires_grad_() + k_ref_bench = k.clone().detach().requires_grad_() + v_ref_bench = v.clone().detach().requires_grad_() + + o_ref_bench, _ = ref_program(q_ref_bench, k_ref_bench, v_ref_bench, is_causal, groups) + + o_ref_bench.backward(dO) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") + + def run_complete_fwd_bwd(): + o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) + + delta_tl_bench = bwd_prep(o_tl_bench, dO) + + dQ_bench = torch.zeros_like(q, dtype=torch.float32) + dK_bench = torch.zeros_like(k, dtype=torch.float32) + dV_bench = torch.zeros_like(v, dtype=torch.float32) + bwd_kernel(q, k, v, dO, lse_tl_bench, delta_tl_bench, dQ_bench, dK_bench, dV_bench) + + post_kernel(dQ_bench) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + tile_latency = benchmark_function(run_complete_fwd_bwd, warmup=10, repeat=100) + print( + f"Complete Flash Attention V2 Forward+Backward (Tile-lang): {tile_latency:.2f} ms | {total_flops / tile_latency * 1e-9:.2f} TFlops" + ) + + speedup = ref_latency / tile_latency + print(f"Speedup: {speedup:.2f}x") + + print("Forward output: Passed") + print(f"dQ: {'Passed' if dq_close else 'Failed'} (Max diff: {dq_max_diff:.6f})") + print(f"dK: {'Passed' if dk_close else 'Failed'} (Max diff: {dk_max_diff:.6f})") + print(f"dV: {'Passed' if dv_close else 'Failed'} (Max diff: {dv_max_diff:.6f})") + + if all([dq_close, dk_close, dv_close]): + print("All checks passed!") + else: + print("Some checks failed, may need further debugging.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + 
parser.add_argument("--groups", type=int, default=1, help="groups") + args = parser.parse_args() + + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/tilelang/original/examples/amd/example_amd_flash_attn_fwd.py b/tilelang/original/examples/amd/example_amd_flash_attn_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9c361ff1235a3f7f49b2900b5c5ee868d92a2e --- /dev/null +++ b/tilelang/original/examples/amd/example_amd_flash_attn_fwd.py @@ -0,0 +1,246 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.tileop.base import GemmWarpPolicy +import itertools +import argparse +from functools import partial + + +# Custom supply function to ensure tensors are created on GPU +def supply_tensors_gpu(params): + """Supply function that creates tensors on GPU for ROCm/HIP.""" + tensors = [] + for param in params: + if hasattr(param, "shape") and hasattr(param, "dtype"): + # Force creation on GPU device + shape = [int(s) for s in param.shape] + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") + tensors.append(tensor) + else: + tensors.append(param) + return tensors + + +def ref_program(Q, K, V, is_causal, groups=1): + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def get_configs(): + """Generates configurations for the autotuner, tailored for FA-2 style parallelism.""" + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + threads = [128, 256, 512] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + enable_rasterization = [True] + k_pack = [2] + panel_size = [7, 8] + qk_coalesced_width = [8] + v_coalesced_width = [4] + + valid_configs = [] + + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) + return valid_configs + + +@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu) +@tilelang.jit(out_idx=[3]) +def fast_flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_split_q: int, + threads: int, + num_stages: int, + enable_rasterization: bool, + k_pack: int, + panel_size: int, + qk_coalesced_width: int, + v_coalesced_width: int, +): + scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + vec_size = qk_coalesced_width + v_vec_size 
= v_coalesced_width + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): + T.use_swizzle(panel_size, enable=enable_rasterization) + + bz = byz_combined // heads + by = byz_combined % heads + + num_q_blocks = T.ceildiv(seq_len, block_M) + + bx = T.alloc_var(T.int32) + bx = b_split + + with T.While(bx < num_q_blocks): + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + m_i = T.alloc_fragment([block_M], accum_dtype) + l_i = T.alloc_fragment([block_M], accum_dtype) + T.fill(acc_o, 0) + T.fill(m_i, -T.infinity(accum_dtype)) + T.fill(l_i, 0) + + current_bx = bx + q_block_offset = current_bx * block_M + + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + # Use register fragment for P instead of shared memory to reduce LDS usage + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + m_prev = T.alloc_fragment([block_M], accum_dtype) + scale_factor = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) + + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + + row_sum = T.alloc_fragment([block_M], accum_dtype) + + for k in T.Pipelined(loop_end_k, num_stages=num_stages): + kv_idx = k * block_N + + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + k_pack=k_pack, + policy=GemmWarpPolicy.FullRow, + ) + + T.copy(m_i, m_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) + + for i in T.Parallel(block_M): + sf = T.exp(m_prev[i] * scale - m_i[i] * scale) + l_i[i] *= sf + scale_factor[i] = sf + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scale_factor[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale) + + T.reduce_sum(acc_s, row_sum, dim=1) + for i in T.Parallel(block_M): + l_i[i] += row_sum[i] + + # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) + + l_inv = T.alloc_fragment([block_M], accum_dtype) + for i in T.Parallel(block_M): + safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0) + l_inv[i] = 1.0 / safe_l + + for i, j in T.Parallel(block_M, dim): + Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i] + + bx = current_bx + num_split_q + + return main + + +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + print("Starting autotuning for FlashAttention-V2...") + kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, 
groups=groups)
+    print(f"Autotuning finished. Best Configuration: {kernel.config}")
+
+    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
+
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+    print("Verifying correctness...")
+    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
+    print("All checks pass.")
+
+    latency = profiler.do_bench(ref_program_processed, warmup=100)
+    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
+
+    latency = profiler.do_bench(warmup=100)
+    print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=8, help="heads")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--is_causal", action="store_true", help="causal")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
diff --git a/tilelang/original/examples/amd/main.o b/tilelang/original/examples/amd/main.o
new file mode 100644
index 0000000000000000000000000000000000000000..304e621fa26ac84e9bc321baa3ae6c1b9f12b186
Binary files /dev/null and b/tilelang/original/examples/amd/main.o differ
diff --git a/tilelang/original/examples/amd/main.o.d b/tilelang/original/examples/amd/main.o.d
new file mode 100644
index 0000000000000000000000000000000000000000..157c1a0d3f92ca0d9926552912d40f00c04be8a4
--- /dev/null
+++ b/tilelang/original/examples/amd/main.o.d
@@ -0,0 +1,3 @@
+main.o: /root/.cache/torch_extensions/py310_cpu/c_dlpack/main.cpp \
+  /usr/local/lib/python3.10/dist-packages/tvm_ffi/include/dlpack/dlpack.h \
+  /usr/local/lib/python3.10/dist-packages/fastpt/torch/include/c10/cuda/CUDAStream.h
diff --git a/tilelang/original/examples/analyze/README.md b/tilelang/original/examples/analyze/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ec0a687547701d6e3a125620306d54fb4e85a9c
--- /dev/null
+++ b/tilelang/original/examples/analyze/README.md
@@ -0,0 +1,111 @@
+# TVM IR Performance Analyzer
+
+A performance analysis toolkit for TVM IR modules. It provides hardware-aware performance metrics, including FLOPs, memory bandwidth utilization, and execution time estimation.
+
+## Features
+
+- **Operation Analysis**: supports arbitrary operations expressed in TVM IR (including GEMM and convolution)
+- **Memory Traffic Calculation**: tracks global memory transfers
+- **Architecture-aware Metrics**: pre-configured with NVIDIA GPU architectures (Ampere, Ada Lovelace)
+- **Performance Estimation**: predicts execution time using a roofline model
+- **TVM Integration**: works with TVM IRModule and PrimFunc
+
+## Quick Start
+### GEMM Analysis Example
+```python
+import tilelang.language as T
+from tilelang.tools import Analyzer
+from tilelang.carver.arch import CUDA
+
+M = N = K = 1024
+
+def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
+    @T.prim_func
+    def main(A: T.Tensor((M, K), T.float16),
+             B: T.Tensor((N, K), T.float16),
+             C: T.Tensor((M, N), T.float)):
+        # ...
(kernel definition) + return main + +cuda_device = CUDA("cuda") +result = Analyzer.analysis(kernel(), cuda_device) +print(result) +``` + +### Convolution Analysis Example +```python +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA + +def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128): + @T.prim_func + def main(data: T.Tensor((N, H, W, C), T.float16), + kernel: T.Tensor((K, K, C, F), T.float16), + out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)): + # ... (convolution kernel definition) + return main + +cuda_device = CUDA("cuda") +result = Analyzer.analysis(kernel(), cuda_device) +print(result) +``` + +## API Documentation +### `AnalysisResult` Class +```python +@dataclass(frozen=True) +class AnalysisResult: + total_flops: int # Total floating-point operations + total_global_bytes: int # Global memory traffic in bytes + estimated_time: float # Predicted execution time (seconds) + tflops: float # Achieved TFLOPS + bandwidth_GBps: float # Memory bandwidth utilization +``` +### `Analyzer` Class Methods +#### `analysis(fn, device)` +* ​Parameters: + * fn: TVM IRModule or PrimFunc + * device: Device configuration object +* Returns: AnalysisResult +#### Supported Architectures +```python +# Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count) +ARCH_CONFIGS = { + "80": (128, 1.41, 2, 108), # A100 + "86": (128, 1.70, 2, 84), # RTX 3080 + "89": (128, 2.52, 2, 128) # RTX 4090 +} +``` + +## Implementation Details + +### Performance Model +Uses roofline model with two constraints: +1. ​**Compute Bound**: `Time = Total FLOPs / (SM Count × Cores/SM × Clock × FLOPs/Cycle)` +2. ​**Memory Bound**: `Time = Memory Bytes / (Bandwidth × Utilization)` + +### IR Analysis Pass +1. ​**Traversal**: Walks through TVM IR using `ir_transform` +2. ​**Operation Detection**: + - Counts FLOPs for all compute operations + - Calculates memory traffic for all memory operations +3. ​**Loop Handling**: + - Tracks nested loops for operation scaling + - Accounts for block/grid dimensions + +## Key Metrics Calculation + +| Metric | Formula | +|-------------------------|-----------------------------------------| +| FLOPs per GEMM | `2 × M × N × K` | +| Memory Traffic per Copy | `elements × dtype_size × loop_product` | +| Achieved TFLOPS | `total_flops / estimated_time / 1e12` | +| Memory Bandwidth | `total_global_bytes / estimated_time` | + +## Limitations +1. Requires memory operations to be properly annotated in the IR +2. 
Assumes perfect memory coalescing and no bank conflicts + +## Supported Operations +Any operation expressed in TVM IR diff --git a/tilelang/original/examples/analyze/example_conv_analyze.py b/tilelang/original/examples/analyze/example_conv_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..db21e02f62bc0ad281848a8085d6b661d2d4e93c --- /dev/null +++ b/tilelang/original/examples/analyze/example_conv_analyze.py @@ -0,0 +1,89 @@ +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.layout import make_swizzled_layout +import torch + +N = 64 +C = 256 +H = 512 +W = 512 +F = 512 +K = 3 +S = 1 +D = 1 +P = 1 + + +def check_hopper(): + # if not torch.cuda.is_available(): + # return None + # props = torch.cuda.get_device_properties(0) + # compute_capability = props.major, props.minor + # return compute_capability == (9, 0) + return False + + +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def conv( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: make_swizzled_layout(out_shared), + data_shared: make_swizzled_layout(data_shared), + kernel_shared: make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return conv + + +def main(): + my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256) + cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + result = Analyzer.analysis(my_func, cuda_device) + print(result) + print(f"Analyzed FLOPs: {result.total_flops}") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/analyze/example_gemm_analyze.py b/tilelang/original/examples/analyze/example_gemm_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..0367af126e04d631d53c9d63eb6f269b807dcf2d --- /dev/null +++ b/tilelang/original/examples/analyze/example_gemm_analyze.py @@ -0,0 +1,60 @@ 
+import tilelang.language as T
+from tilelang.tools import Analyzer
+from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import CDNA
+import torch
+
+M = N = K = 1024
+
+
+def kernel(
+    block_M=None,
+    block_N=None,
+    block_K=None,
+    num_stages=None,
+    thread_num=None,
+    enable_rasteration=None,
+):
+    dtype = T.float16
+    accum_dtype = T.float32
+
+    @T.prim_func
+    def matmul(
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_N, block_K), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
+            T.use_swizzle(panel_size=10, enable=enable_rasteration)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[bx * block_N, k * block_K], B_shared)
+                T.gemm(
+                    A_shared,
+                    B_shared,
+                    C_local,
+                    transpose_B=True,
+                )
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    return matmul
+
+
+def main():
+    my_func = kernel(128, 128, 32, 3, 128, True)
+
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
+    result = Analyzer.analysis(my_func, cuda_device)
+
+    print(f"Analyzed FLOPs: {result.total_flops}")
+    print(f"Expected FLOPs: {2 * M * N * K}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tilelang/original/examples/analyze/test_example_analyze.py b/tilelang/original/examples/analyze/test_example_analyze.py
new file mode 100644
index 0000000000000000000000000000000000000000..448844b900270b61ca452d3acc25104d061d1492
--- /dev/null
+++ b/tilelang/original/examples/analyze/test_example_analyze.py
@@ -0,0 +1,15 @@
+import tilelang.testing
+import example_gemm_analyze
+import example_conv_analyze
+
+
+def test_example_gemm_analyze():
+    example_gemm_analyze.main()
+
+
+def test_example_conv_analyze():
+    example_conv_analyze.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/original/examples/attention_sink/README.md b/tilelang/original/examples/attention_sink/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ed4b7004e6283c7c2b7a5cdeffbf7da90e1dcca4
--- /dev/null
+++ b/tilelang/original/examples/attention_sink/README.md
@@ -0,0 +1,46 @@
+# Attention Sink
+
+We compare against an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
+
+
+## Algorithm
+### Forward
+The only change from vanilla FlashAttention is that `sinks` must be taken into account in the softmax, which requires an extra rescaling at the epilogue stage.
+
+### Backward
+Interestingly, a detailed derivation shows that the backward computation of `dQ`, `dK`, and `dV` is almost identical to that of vanilla FlashAttention, except that the precise meaning of `lse` differs. The only extra quantity to compute is `dsinks`, given by:
+
+$$
+dsink_h = -\sum_{b}\sum_{q} P_{b,h,q} \Delta_{b,h,q}
+$$
+
+where $P_{b,h,q}$ is the proportion of $sink_h$ in the softmax for the $b$-th batch, $h$-th head, and $q$-th query (row), and $\Delta_{b,h,q}$ is the usual row-wise sum of $O \odot dO$ from the FlashAttention backward pass.
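+
+As a sanity check, this formula can be evaluated directly on dense tensors. The sketch below assumes `sinks` is the per-head sink logit of shape `[heads]`, `lse` is the log-sum-exp that includes the sink term, and `delta` is the row-wise sum of `O * dO`, both of shape `[batch, heads, seq_q]` (the names are illustrative, not this example's API):
+
+```python
+import torch
+
+def dsinks_reference(sinks, lse, delta):
+    # P[b, h, q] = exp(sink_h - lse[b, h, q]): the sink's share of the softmax.
+    p = torch.exp(sinks.view(1, -1, 1) - lse)
+    # dsink_h = -sum over batches b and queries q of P[b, h, q] * Delta[b, h, q]
+    return -(p * delta).sum(dim=(0, 2))  # shape: [heads]
+```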
+ +## Benchmark of forward process + +### Benchmark Environment +- **Hardware**: NVIDIA H800 +- **CUDA version**: 12.9 +- **Triton Version**: 3.4.0 + +### Results + +- dtype=bfloat16 +- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B) +- Full attention is adopted. + +| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup | +|---------|---------|---------------|----------------------|---------| +| 2048 | 64 | 232.98 | **281.89** | 1.21x | +| 2048 | 128 | 321.55 | **417.98** | 1.30x | +| | | | | | +| 4096 | 64 | 280.70 | **349.47** | 1.25x | +| 4096 | 128 | 369.61 | **497.13** | 1.35x | +| | | | | | +| 8192 | 64 | 299.04 | **385.56** | 1.29x | +| 8192 | 128 | 399.39 | **507.93** | 1.27x | +| | | | | | +| 16384 | 64 | 309.46 | **400.62** | 1.29x | +| 16384 | 128 | 418.99 | **549.11** | 1.31x | + +> The backward performance will be further optimized in the future. \ No newline at end of file diff --git a/tilelang/original/examples/attention_sink/benchmark_gqa_sink_fwd.py b/tilelang/original/examples/attention_sink/benchmark_gqa_sink_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..211ef1d18cda28f20c6104b17f0330322a437d3f --- /dev/null +++ b/tilelang/original/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -0,0 +1,211 @@ +import torch +import argparse +from tilelang.profiler import do_bench +from tilelang import language as T +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor +from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional + + +@triton.jit +def triton_kernel( + Q, + K, + V, + Sinks, + sm_scale, + Out, + Z, + H, + N_Q_CTX, + N_KV_CTX, + HEAD_DIM: tl.constexpr, + groups: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BANDWIDTH: tl.constexpr, + start_q: tl.constexpr, +): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + # load attention sinks + if Sinks is not None: # noqa: SIM108 + sink = tl.load(Sinks + off_h).to(tl.float32) + else: + sink = 0 + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) + + if BANDWIDTH: + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + else: + lo, hi = 0, start_q + (start_m + 1) * BLOCK_M + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] + + if BANDWIDTH: + too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) + mask = mask | too_old + + k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T + qk = tl.dot(q, k, allow_tf32=False) + + qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp(qk) + alpha = tl.math.exp(m_i - m_ij) + l_ij = tl.sum(p, 1) + acc = acc * alpha[:, None] + + v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]) + # v = v.to(tl.float32) + p = p.to(v.dtype) # We perform fp16 gemm to utilize 
tensor core + acc = tl.dot(p, v, acc, allow_tf32=False) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + sink = tl.math.exp(sink - m_i) + z = l_i + sink + acc = acc / z[:, None] + # m_i += tl.math.log(l_i) + # m_ptrs = M + off_hz * N_Q_CTX + offs_m + # tl.store(m_ptrs, m_i) + acc = acc.to(Out.dtype)[None, None, :, :] + Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) + + +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: + bs, n_heads, seq_q, head_dim = Q.shape + _, n_heads_kv, seq_kv, _ = K.shape + BLOCK_M = 64 + BLOCK_N = 64 + groups = n_heads // n_heads_kv + + o = torch.empty_like(Q) + grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1) + triton_kernel[grid]( + TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]), + TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]), + TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]), + Sinks, + 1.0 / head_dim**0.5, + TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]), + bs, + n_heads, + N_Q_CTX=seq_q, + N_KV_CTX=seq_kv, + HEAD_DIM=head_dim, + groups=groups, + BANDWIDTH=window_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + start_q=seq_kv - seq_q, + ) + return o + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + + if torch.allclose( + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): + print("Checks for triton passed.✅") + else: + print("Checks for triton failed.❌") + + # Benchmark triton + latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) + print("Triton: {:.2f} ms".format(latency_triton)) + print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9)) + + # Benchmark tilelang + latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency_tilelang)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) + + print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, 
default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/tilelang/original/examples/attention_sink/benchmark_mha_sink_fwd.py b/tilelang/original/examples/attention_sink/benchmark_mha_sink_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..50747e6b09d902668ec99ee5c267c9dccadf208f --- /dev/null +++ b/tilelang/original/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -0,0 +1,198 @@ +import torch +import argparse +from tilelang.profiler import do_bench +from tilelang import language as T +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor +from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional + + +@triton.jit +def triton_kernel( + Q, + K, + V, + Sinks, + sm_scale, + Out, + Z, + H, + N_Q_CTX, + N_KV_CTX, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BANDWIDTH: tl.constexpr, + start_q: tl.constexpr, +): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + # load attention sinks + if Sinks is not None: # noqa: SIM108 + sink = tl.load(Sinks + off_h).to(tl.float32) + else: + sink = 0 + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) + + if BANDWIDTH: + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + else: + lo, hi = 0, start_q + (start_m + 1) * BLOCK_M + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] + + if BANDWIDTH: + too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) + mask = mask | too_old + + k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T + qk = tl.dot(q, k, allow_tf32=False) + + qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp(qk) + alpha = tl.math.exp(m_i - m_ij) + l_ij = tl.sum(p, 1) + acc = acc * alpha[:, None] + + v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]) + # v = v.to(tl.float32) + p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core + acc = tl.dot(p, v, acc, allow_tf32=False) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + sink = tl.math.exp(sink - m_i) + z = l_i + sink + acc = acc / z[:, None] + # m_i += tl.math.log(l_i) + # m_ptrs = M + off_hz * N_Q_CTX + 
offs_m + # tl.store(m_ptrs, m_i) + acc = acc.to(Out.dtype)[None, None, :, :] + Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) + + +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: + bs, n_heads, seq_q, head_dim = Q.shape + seq_kv = K.shape[2] + BLOCK_M = 64 + BLOCK_N = 64 + + o = torch.empty_like(Q) + grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1) + triton_kernel[grid]( + TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]), + TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]), + TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]), + Sinks, + 1.0 / head_dim**0.5, + TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]), + bs, + n_heads, + N_Q_CTX=seq_q, + N_KV_CTX=seq_kv, + HEAD_DIM=head_dim, + BANDWIDTH=window_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + start_q=seq_kv - seq_q, + ) + return o + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.✅") + + latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) + print("Triton: {:.2f} ms".format(latency)) + print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, 
args.dim, args.window_size, args.dtype, args.tune) diff --git a/tilelang/original/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/tilelang/original/examples/attention_sink/example_gqa_sink_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..541baca0430a4378220ce8b48868e64d4014e5dc --- /dev/null +++ b/tilelang/original/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -0,0 +1,512 @@ +# Adapted from tilelang/examples/flash_attention/example_gqa_bwd.py + +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional + + +def get_bwd_configs(): + sm_major, sm_minor = torch.cuda.get_device_capability() + sm_version = sm_major * 10 + sm_minor + if sm_version == 80: + return 64, 32, 1, 128 + elif sm_version == 90: + return 128, 32, 2, 256 + else: + raise ValueError(f"Unsupported SM version: {sm_version}") + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd( + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_len, dim] + kv_shape = [batch, head_kv, seq_len, dim] + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([heads], dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, 
-T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # 
type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_len, dim] + kv_shape = [batch, head_kv, seq_len, dim] + accum_dtype = T.float32 + + block_M, block_N, num_stages, threads = get_bwd_configs() + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + for i, j in T.Parallel(block_M, block_N): + if window_size is not None: + qkT[i, j] = T.if_then_else( + by * block_M 
+ i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) + else: + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) + + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): + sink = T.alloc_local([1], dtype) + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], dtype) + + sink[0] = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sinks, window_size, groups): + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)] + BATCH, H, N_CTX, D_HEAD = q.shape + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + o, lse = kernel(q, k, v, sinks) + ctx.save_for_backward(q, k, v, sinks, o, lse) + ctx.window_size = window_size + ctx.groups = groups + return o + + @staticmethod + def backward(ctx, do): + q, k, v, sinks, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + groups = ctx.groups + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype) + q_shape = [BATCH, H, N_CTX, D_HEAD] + head_kv = H // groups + kv_shape = [BATCH, head_kv, N_CTX, D_HEAD] + dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd + dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device) + dv = torch.zeros(kv_shape, 
dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + + kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) + return dq, dk, dv, dsinks, None, None + + +attention = _attention.apply + + +# Adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim) + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + + start_q = num_keys - num_queries + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= N_CTX + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 + total_flops = 5 * flops_per_matmul + + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + V = torch.randn_like(K).requires_grad_() + sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() + dO = torch.randn_like(Q) + + O = attention(Q, K, V, sinks, window_size, groups) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + O_ref = ref_program(Q, K, V, sinks, window_size, 
dtype=torch_dtype) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + rtol, atol = { + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), + }[dtype] + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.✅") + + # Only benchmark backward here + def torch_bwd(): + O_ref.backward(dO, retain_graph=True) + + def tl_bwd(): + O.backward(dO, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/tilelang/original/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/tilelang/original/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..df157cd0ff396c3f5e358c71334dc77695e72315 --- /dev/null +++ b/tilelang/original/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,332 @@ +# Modified from tilelang/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_configs(), + warmup=500, + rep=100, +) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups=1, + window_size=None, # None for full attention + 
sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, head_kv, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined( + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Following functions are adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim) + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + + start_q = num_keys 
- num_queries + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.✅") + + # Benchmark tilelang + latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency_tilelang)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", 
type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/tilelang/original/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/tilelang/original/examples/attention_sink/example_mha_sink_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..be405e8bc3c986d0b3241d650c1ab1652e1081ec --- /dev/null +++ b/tilelang/original/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -0,0 +1,505 @@ +# Adapted from tilelang/examples/flash_attention/example_mha_bwd_bhsd.py + +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional + + +def get_bwd_configs(): + sm_major, sm_minor = torch.cuda.get_device_capability() + sm_version = sm_major * 10 + sm_minor + if sm_version == 80: + return 64, 32, 1, 128 + elif sm_version == 90: + return 128, 32, 2, 256 + else: + raise ValueError(f"Unsupported SM version: {sm_version}") + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd( + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + shape = [batch, heads, seq_len, dim] + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([heads], dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + 
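# The swizzled layout lets the copy below and the tensor-core GEMM read Q_shared without shared-memory bank conflicts. +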
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + 
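# delta holds rowsum(O * dO); the backward kernel reuses it as Delta when forming dS = P * (dP - Delta), and the dsink kernel consumes it as well. +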
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention + sm_scale=None, + dtype: T.dtype = T.float16, +): + block_M, block_N, num_stages, threads = get_bwd_configs() + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + shape = [batch, heads, seq_len, dim] + accum_dtype = T.float32 + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + 
T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + for i, j in T.Parallel(block_M, block_N): + if window_size is not None: + qkT[i, j] = T.if_then_else( + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) + else: + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): + sink = T.alloc_local([1], dtype) + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], accum_dtype) + + sink[0] = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sinks, window_size): + BATCH, H, N_CTX, D_HEAD = q.shape + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + o, lse = kernel(q, k, v, sinks) + ctx.save_for_backward(q, k, v, sinks, o, lse) + ctx.window_size = window_size + return o + + @staticmethod + def backward(ctx, do): + q, k, v, sinks, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + + def maybe_contiguous(x): + if 
x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype) + shape = [BATCH, H, N_CTX, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd + dk = torch.empty(shape, dtype=q.dtype, device=q.device) + dv = torch.empty(shape, dtype=q.dtype, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + + kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) + return dq, dk, dv, dsinks, None + + +attention = _attention.apply + + +# Adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1) + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= N_CTX + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 + total_flops = 5 * flops_per_matmul + + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, 
dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn_like(Q).requires_grad_() + V = torch.randn_like(Q).requires_grad_() + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() + dO = torch.randn_like(Q) + + O = attention(Q, K, V, sinks, window_size) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + rtol, atol = { + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), + }[dtype] + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.✅") + + # Only benchmark backward here + def torch_bwd(): + O_ref.backward(dO, retain_graph=True) + + def tl_bwd(): + O.backward(dO, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..f6754bd94acf6fe9ca440cc9058ed7080b5ed267 --- /dev/null +++ b/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -0,0 +1,315 @@ +# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd.py + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + 
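+# get_configs() enumerates the Cartesian product of the parameter lists above (1 * 1 * 3 * 2 = 6 candidates), +# e.g. {"block_M": 128, "block_N": 128, "num_stages": 2, "threads": 256}; @autotune below benchmarks each one +# and exposes the selected config via kernel.config / kernel.latency (see main()).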
+@autotune(configs=get_configs(), warmup=500, rep=100) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
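+                # Since scale = sm_scale * log2(e) = sm_scale * 1.44269504, exp2(acc_s * scale - scores_max * scale) equals exp((acc_s - scores_max) * sm_scale).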
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = 
torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.✅") + + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + 
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaf2ce33941587ceeefc49627424f58adeca276 --- /dev/null +++ b/tilelang/original/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,322 @@ +# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=500, rep=100) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), 
+ k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + 
past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined( + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Following functions are adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window 
attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.✅") + + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/tilelang/original/examples/attention_sink/test_example_attention_sink.py b/tilelang/original/examples/attention_sink/test_example_attention_sink.py new file mode 100644 index 0000000000000000000000000000000000000000..57242c199c4b70345fc3fe29d3559da18e4ac990 --- /dev/null +++ b/tilelang/original/examples/attention_sink/test_example_attention_sink.py @@ -0,0 +1,65 @@ +import tilelang.testing + +import example_mha_sink_fwd_bhsd +import example_mha_sink_fwd_bhsd_wgmma_pipelined +import example_gqa_sink_fwd_bhsd_wgmma_pipelined +import example_mha_sink_bwd_bhsd +import example_gqa_sink_bwd_bhsd + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_fwd_bhsd_full_attn(): + example_mha_sink_fwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_fwd_bhsd_sliding_window(): + example_mha_sink_fwd_bhsd.main(window_size=128) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn(): + example_mha_sink_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128) + + 
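+# The *_wgmma_pipelined examples target Hopper, hence the requires_cuda_compute_version_ge(9, 0) guards on those tests.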
+@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn(): + example_gqa_sink_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128) + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_bwd_bhsd(): + example_mha_sink_bwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_bwd_bhsd_sliding_window(): + example_mha_sink_bwd_bhsd.main(window_size=128) + + +@tilelang.testing.requires_cuda +def test_example_gqa_sink_bwd_bhsd(): + example_gqa_sink_bwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_gqa_sink_bwd_bhsd_sliding_window(): + example_gqa_sink_bwd_bhsd.main(window_size=128) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/bitnet-1.58b/.gitignore b/tilelang/original/examples/bitnet-1.58b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6ea8874968d000cd47f52f55f32a92f0127532b3 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/.gitignore @@ -0,0 +1 @@ +models/ \ No newline at end of file diff --git a/tilelang/original/examples/bitnet-1.58b/README.md b/tilelang/original/examples/bitnet-1.58b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2b587eab4cc6128965be2e4cac4a5d68db13a86c --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/README.md @@ -0,0 +1,97 @@ +--- +license: mit +--- + + +This is a Tilelang implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with an INT8xINT2 kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. + +## Make Checkpoints for vLLM + +We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which makes a checkpoint with fp16 uncompressed metadata; the main difference from the original checkpoint is the `quant_config.json`, which allows vLLM to load the model and execute it with a quant extension. + +```bash +# move to the integration directory +cd /root/to/BitBLAS/integration/BitNet +# make the checkpoint +./maint/generate_bitnet_model_native_format.sh +# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory +``` + +The second script is `generate_bitnet_model_bitblas_format.sh`, which makes a checkpoint with BitBLAS compressed metadata; this avoids the online dequantization stage during vLLM profiling, which leads to more efficient memory utilization. 
+ +```bash +./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas +# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory +``` + +Finally, you can use the ckpt in vLLM with: + +```bash +cd vllm_workspace +# inference with the ckpt with fp16 uncompressed metadata +python3 inference_with_native_format.py +# inference with the ckpt with BitBLAS compressed metadata +python3 inference_with_bitblas_format.py +``` + +**Benchmark results of vLLM** + +| Model | Framework | BS16IN32OUT128 | BS1IN512OUT1024 | BS32IN32OUT128 | +|------------------------|--------------------------|----------------|-----------------|----------------| +| bitnet-3b-1.58bits | pytorch | 106.83 | 49.34 | 209.03 | +| bitnet-3b-1.58bits | pytorch-tilelang | 240.33 | 103.09 | 493.31 | +| bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 | +| bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 | + + +## BitBLAS Results + +### Performance + +**Note:** To reproduce the BitBLAS results, please check out `benchmark_inference_latency.py`. To reproduce the results of the original model, please check out the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo. + +| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas | +|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 | +| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 | +| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 | + +### On-the-Fly GPU Memory Footprint + +We measured the GPU memory footprint with the `nvidia-smi` command. Please check out `nvidia_measure_memory.sh` to get the real-time GPU memory usage, and then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage. + +| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** | +|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB | +| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB | +| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB | + +## PPL and Zero-shot Accuracy + +The numbers are reported from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B); please check out `eval_ppl.py`. 
+ +PPL and zero-shot accuracy: +| Models | PPL| ARCe| ARCc| HS | BQ | OQ | PQ | WGe | Avg +|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 | +| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 | +| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 | +| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 +| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 | +| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 +| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 +| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 +| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 | + +The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors. + +## Citations + +```bibtex +@article{ma2024era, + title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits}, + author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu}, + journal={arXiv preprint arXiv:2402.17764}, + year={2024} +} +``` \ No newline at end of file diff --git a/tilelang/original/examples/bitnet-1.58b/benchmark.sh b/tilelang/original/examples/bitnet-1.58b/benchmark.sh new file mode 100755 index 0000000000000000000000000000000000000000..6a2550d45562387677cf169ae66744fcd6a8657e --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/benchmark.sh @@ -0,0 +1,11 @@ +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 | tee b32_i32_o128.log + +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b16_i32_o128_bitblas.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 --bitblas | tee b1_i512_o64_bitblas.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b32_i32_o128_bitblas.log diff --git a/tilelang/original/examples/bitnet-1.58b/benchmark_generate.py b/tilelang/original/examples/bitnet-1.58b/benchmark_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..d678b91a4e1c970e2209d2dfc0a102af4c3cf81b --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/benchmark_generate.py @@ -0,0 +1,114 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time +import argparse + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text_batch(model, tokenizer, prompts, max_length=100): + # Encode the input prompts as a batch + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + + # Generate cos and sin values (commented out as not used in generation) + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, 
device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # position_embeddings = model.embed_positions(position_ids) + # cos = position_embeddings[:, :, 0::2].cos() + # sin = position_embeddings[:, :, 1::2].sin() + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) + end_time = time.time() + + # Decode the output ids to text + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] + + generation_time = end_time - start_time + num_tokens = sum(len(output_id) for output_id in output_ids) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_texts + + +def profile(model, input_data): + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +model_path = "1bitLLM/bitnet_b1_58-3B" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") + args = parser.parse_args() + bs = args.bs + in_seq_len = args.in_seq_len + out_seq_len = args.out_seq_len + is_bitblas = args.bitblas + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + if is_bitblas: + with torch.no_grad(): + model.quantize() + + tokenizer = BitnetTokenizer.from_pretrained(model_path) + prompt = "" + for _ in range(in_seq_len): + prompt += "Hello " + + prompts = [] + for _ in range(bs): + prompts.append(prompt) + max_length = out_seq_len + in_seq_len + print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/benchmark_inference_latency.py b/tilelang/original/examples/bitnet-1.58b/benchmark_inference_latency.py new file mode 100644 index 0000000000000000000000000000000000000000..788fc5565d5d58b59ef11a11b33f357e911ba9bc --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -0,0 +1,57 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with 
torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + with torch.no_grad(): + model.quantize() + model = torch.compile(model) + + benchmark_sets = [(1024, 1), (1, 2048)] + for batch_size, seq_len in benchmark_sets: + input_id = torch.ones(batch_size, seq_len).long().cuda() + latency = profile(model, input_id) + print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/benchmark_model_10k_loops.py b/tilelang/original/examples/bitnet-1.58b/benchmark_model_10k_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..306c88428277b591bff935be701e5401a8faaf54 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/benchmark_model_10k_loops.py @@ -0,0 +1,63 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--batch_size", default=1, type=int) +parser.add_argument("--seq_len", default=1, type=int) + +args = parser.parse_args() + +seq_len = args.seq_len +batch_size = args.batch_size + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + with torch.no_grad(): + model._post_process_weights() + + torch.cuda.empty_cache() + + input_id = torch.ones(batch_size, seq_len).long().cuda() + for _ in range(10000): + _ = model(input_id) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/configuration_bitnet.py b/tilelang/original/examples/bitnet-1.58b/configuration_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..63c499db36d96d50f567794bf80a60882e08114f --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/configuration_bitnet.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class BitnetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BitnetModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BitnetModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Bitnet 1 supports up to 2048 tokens, + Bitnet 2 up to 4096, CodeBitnet up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. 
+ pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import BitnetModel, BitnetConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = BitnetConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = BitnetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + weight_bits=1, + input_bits=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.weight_bits = weight_bits + self.input_bits = input_bits + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + 
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/tilelang/original/examples/bitnet-1.58b/eval_correctness.py b/tilelang/original/examples/bitnet-1.58b/eval_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..11d47004b81edf517d442cb0eb2b70e6c583cce0 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/eval_correctness.py @@ -0,0 +1,99 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time +import transformers + +print(f"transformers version is {transformers.__version__}") + +# version must be lower than or equal to 4.40 +assert transformers.__version__ <= "4.40.0" + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def profile(model, input_data): + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +model_path = "1bitLLM/bitnet_b1_58-3B" + + +def main(): + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + + tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) + input_id = 
tokenizer("Hello")["input_ids"] + input_id = torch.tensor(input_id).unsqueeze(0).cuda() + + print("original model generated text:") + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + model.quantize() + print("quantized model generated text:") + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/eval_gpu_memory.py b/tilelang/original/examples/bitnet-1.58b/eval_gpu_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..00c914cb31c919fc536d0705f59cacf29a30e287 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/eval_gpu_memory.py @@ -0,0 +1,52 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") + with torch.no_grad(): + model._post_process_weights() + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/eval_ppl.py b/tilelang/original/examples/bitnet-1.58b/eval_ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..97db2d0f5236f369a33f70ac1b07fe9a8c01df9d --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/eval_ppl.py @@ -0,0 +1,72 @@ +# pylint: disable=missing-docstring, invalid-name +"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py.""" + +import math +import argparse +import torch +import random + +from eval_utils import get_test_dataset +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + +from tqdm import tqdm + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) + + +def calulate_loss(model, input, loss_fct): + output = model(input, use_cache=False, output_hidden_states=False, output_attentions=False)[0] + shift_logits = output[:, :-1, :].contiguous() + shift_labels = input[:, 1:] + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + return loss + + +def main(args): + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + with torch.no_grad(): + model._post_process_weights() + tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) + 
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda() + + ppl = [] + for dataset in datasets: + testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen) + acc_loss, count = 0.0, 0 + progress = tqdm(range(len(testdata))) + for ii in progress: + input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) + loss = calulate_loss(model, input, loss_fct) + count += input.size(-1) - 1 + acc_loss += loss.item() + progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}") + + avg_loss = acc_loss / count / math.log(2) + ppl.append(2**avg_loss) + print("{} PPL: {}".format(dataset, ppl[-1])) + + print(ppl) + print("Avg PPL:", sum(ppl) / len(ppl)) + + +if __name__ == "__main__": + torch.set_grad_enabled(False) + args = parser.parse_args() + random.seed(args.seed) + torch.random.manual_seed(args.seed) + main(args) diff --git a/tilelang/original/examples/bitnet-1.58b/eval_utils.py b/tilelang/original/examples/bitnet-1.58b/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72480c392a7cfa40081546d2da19aa31463aab76 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/eval_utils.py @@ -0,0 +1,135 @@ +# ruff: noqa +import torch + +import numpy as np +import torch.nn.functional as F + +from lm_eval.base import BaseLM +from datasets import load_dataset + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_test_dataset(dataset_name, tokenizer, seqlen=2048): + if dataset_name == "wikitext2": + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") + elif dataset_name == "c4": + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] + else: + raise NotImplementedError + + testdata = [item for item in testdata if item != ""] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata] + + data, doc = [], [tokenizer.bos_token_id] + for sen in tokenized_text: + if len(sen) > seqlen: + continue + if len(doc) + len(sen) > seqlen: + data.append(doc) + doc = [tokenizer.bos_token_id] + doc.extend(sen) + if len(doc) > 1 and len(doc) <= seqlen: + data.append(doc) + return data + + +class LMEvalAdaptor(BaseLM): + def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): + super().__init__() + + assert isinstance(batch_size, int) + + self.model_name = model_name + self.model = model + self.model.eval() + + self.tokenizer = tokenizer + + self.vocab_size = self.tokenizer.vocab_size + + self._batch_size = batch_size + + self._max_length = max_length + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length != -1: + return self._max_length + if hasattr(self.model.config, "n_ctx"): + return self.model.config.n_ctx + elif hasattr(self.model.config, "max_position_embeddings"): + return self.model.config.max_position_embeddings + elif hasattr(self.model.config, "n_positions"): + return self.model.config.n_positions + elif "bloom" in self.model_name: + return 2048 + elif "llama" in self.model_name: + return 2048 # TODO: did not check this + elif "mpt" in self.model_name: + return 2048 + elif "falcon" in self.model_name: + return 2048 + else: + print(self.model.config) + raise NotImplementedError + + @property + def 
max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self._batch_size + + @property + def device(self): + return "cuda" + + def tok_encode(self, string: str, add_special_tokens=True): + return self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests): + new_reqs = [] + for context, continuation in requests: + context, continuation = context.strip(), continuation.strip() + if context == "": + # end of text as context + context_enc = [self.eot_token_id] + else: + context_enc = self.tok_encode(context, add_special_tokens=True) + + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + out = self.model(inps)[0] + return out + + def _model_generate(self, context, max_length, eos_token_id): + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8b7b95cdb24f1bba466a8e776796b7ab025315 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -0,0 +1,262 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +import numpy as np + +from tilelang.transform import simplify_prim_func + +torch.manual_seed(42) + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. 
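+    // Note: the lop3 immLut below encodes f(a, b, c) = (a & b) | c; since I8s_MAGIC_NUM is 0,
+    // each iteration simply masks out the low 2 bits of every byte lane (no fp16 magic-number trick here).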
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +@simplify_prim_func +def bitnet_158_int8xint2_decode( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, + n_partition=4, + reduce_thread=32, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + storage_nbit = 8 + num_bits = 2 + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + C_shape = (M, N) + + num_elems_per_byte = 4 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + micro_size_k_compressed = micro_size_k // num_elems_per_byte + storage_dtype = T.int8 + block_K = reduce_thread * micro_size_k + + use_dp4a = True + dp4a_size = 4 + + @T.prim_func + def kernel( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel( + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), + ) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") + + T.import_source(decode_i2s_to_i8s) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] + + for v in T.vectorized(micro_size_k_compressed): + B_quant_local[v] = B[ + bx * n_partition + ni, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, + ] + + T.call_extern( + "handle", + "decode_i2u_to_i8s", + T.address_of(B_quant_local[0]), + T.address_of(B_dequantize_local[0]), + ) + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_dequantize_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki] * B_dequantize_local[ki] + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return kernel + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + 
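+            # OR element k of the group into its slot: it occupies bits [source_bits*k, source_bits*(k+1))
+            # of the packed byte (with source_bits=2 this packs four 2-bit weights per int8).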
int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == T.int8 else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == T.int8: + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == T.float16: + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == T.float16: + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) + + +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(program) + kernel = tilelang.compile(program) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8) + qw = interleave_weight(qw, 2, target_dtype=in_dtype) + qw = torch.from_numpy(qw).to(device="cuda") + + kernel(A, qw, C) + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32) diff --git a/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py new file mode 100644 index 0000000000000000000000000000000000000000..8c337398233f32905bd4dd929490287d38660126 --- /dev/null +++ 
b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -0,0 +1,385 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tilelang.intrinsics.mma_layout import ( + make_mma_swizzle_layout as make_swizzle_layout, +) +import numpy as np + +from tilelang.intrinsics.mma_macro_generator import ( + INT4TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + +torch.manual_seed(42) + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +@simplify_prim_func +def bitnet_158_int8xint2_prefill( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=64, +): + """ + Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C. + + The returned prim_func expects: + - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8). + - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte). + - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32). + + Details: + - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter. + - Tiling parameters: + - block_row_warps, block_col_warps: number of warps per block in row/col. 
+ - warp_row_tiles, warp_col_tiles: tiles per warp. + - chunk: K-sized chunk per block (block_K). + - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32). + - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior. + - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values. + + Parameters: + M, N, K (int): Global matrix dimensions. + in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8. + out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32. + accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32). + fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used). + block_row_warps (int): Warps in block row dimension. + block_col_warps (int): Warps in block column dimension. + warp_row_tiles (int): Tiles per warp in row dimension. + warp_col_tiles (int): Tiles per warp in column dimension. + chunk (int): K-length per block (block_K). + + Returns: + T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution. + """ + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == T.int32: + micro_size_k = 32 + + num_elems_per_byte = 4 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + shared_scope = "shared.dyn" + storage_dtype = T.int8 + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + """ + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. 
+ - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + + Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. + """ + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, + ) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): + B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v + vi, vj = T.index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj] + + T.call_extern( + "handle", + "decode_i2u_to_i8s", + T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), + ) + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + 
j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == T.int8 else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == T.int8: + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == T.float16: + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == T.float16: + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) + + +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(program) + kernel = tilelang.compile(program) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8) + qw = interleave_weight(qw, 2, target_dtype=in_dtype) + qw = torch.from_numpy(qw).to(device="cuda") + + kernel(A, qw, C) + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, 
accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32) diff --git a/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d35df4b268e617f4d12e374b63dd51c7b3b071 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -0,0 +1,220 @@ +import torch +import torch.backends +from bitblas import tvm as tvm +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from bitblas.base import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local 
= T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + if in_dtype == T.int8: + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32) diff --git a/tilelang/original/examples/bitnet-1.58b/load_from_quantized.py b/tilelang/original/examples/bitnet-1.58b/load_from_quantized.py new file mode 100644 index 0000000000000000000000000000000000000000..8c775aa4c8e819ee3ac800fce4ebe0452fac54be --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/load_from_quantized.py @@ -0,0 +1,71 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +import os 
+from transformers import GenerationConfig
+import time
+
+filepath = os.path.abspath(__file__)
+dirpath = os.path.dirname(filepath)
+
+torch.set_grad_enabled(False)
+bitblas.set_log_level("INFO")
+
+model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
+saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")
+
+
+def generate_text(model, tokenizer, prompt, max_length=100):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
+    # Generate cos and sin values
+    seq_length = input_ids.size(1)
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+    generation_config = GenerationConfig(
+        max_length=max_length,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        num_return_sequences=1,
+    )
+
+    start_time = time.time()
+    output_ids = model.generate(input_ids, generation_config=generation_config)
+    end_time = time.time()
+
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+    generation_time = end_time - start_time
+    num_tokens = len(output_ids[0])
+    tokens_per_second = num_tokens / generation_time
+
+    print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
+    print(f"Tokens per second: {tokens_per_second:.2f}")
+
+    return generated_text
+
+
+def main():
+    # load quantized model
+    qmodel = (
+        BitnetForCausalLM.from_quantized(
+            saved_model_path,
+        )
+        .cuda()
+        .half()
+    )
+    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+    # print("original model generated text:")
+    # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
+    input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
+    # naive model inference
+    output = qmodel(input_ids)
+    print("quantized model output:", output)
+    print("quantized model generated text:")
+    print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tilelang/original/examples/bitnet-1.58b/maint/README.md b/tilelang/original/examples/bitnet-1.58b/maint/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..63cc3e275f18b8bec8e96eabc49c1e812218aee3
--- /dev/null
+++ b/tilelang/original/examples/bitnet-1.58b/maint/README.md
@@ -0,0 +1,91 @@
+---
+license: mit
+---
+
+
+This is a BitBLAS implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with the BitBLAS INT8xINT2 kernel, and we evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`.
+
+## Latest News
+
+- 08/09/2024 ✨: We provide a more efficient implementation of BitNet with vLLM, which requires special model checkpoints. To create the checkpoints and learn how to deploy them, please check out [Make Checkpoints for vLLM](#make-checkpoints-for-vllm).
+
+## Make Checkpoints for vLLM
+
+We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which creates a checkpoint with fp16 uncompressed metadata; the main difference from the original checkpoint is the added `quantize_config.json`, which allows vLLM to load the model and execute it with a quantization extension.
+
+```bash
+# move to the integration directory
+cd /root/to/BitBLAS/integration/BitNet
+# make the checkpoint
+./maint/generate_bitnet_model_native_format.sh
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
+```
+
+The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata. This avoids the online dequantize stage when profiling with vLLM and leads to more efficient memory utilization.
+
+```bash
+./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
+```
+
+Finally, you can use the checkpoint in vLLM with:
+
+```bash
+cd vllm_workspace
+# inference with the ckpt with fp16 uncompressed metadata
+python3 inference_with_native_format.py
+# inference with the ckpt with BitBLAS compressed metadata
+python3 inference_with_bitblas_format.py
+```
+
+## BitBLAS Results
+
+### Performance
+
+**Note:** To reproduce the BitBLAS results, please check out `benchmark_inference_latency.py`. To reproduce the results of the original model, please check out the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
+
+| Model | Device | batchsize | in_seq | arch | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
+|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:|
+| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 |
+| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 |
+| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 |
+
+### On-the-Fly GPU Memory Footprint
+
+We measured the GPU memory footprint with the `nvidia-smi` command. Please check out `nvidia_measure_memory.sh` to get the real-time GPU memory usage, and then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage.
+
+| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** |
+|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:|
+| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB |
+| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB |
+| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB |
+
+## PPL and Zero-shot Accuracy
+
+The numbers below are reported from the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo; to reproduce the perplexity numbers, please check out `eval_ppl.py` (a sample invocation is shown below).
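+
+`eval_ppl.py` exposes `--hf_path`, `--seqlen`, and `--seed` flags (see the argument parser in that script). Assuming the dependencies are installed and a CUDA GPU is available, a typical run from this example directory is:
+
+```bash
+python eval_ppl.py --hf_path 1bitLLM/bitnet_b1_58-3B --seqlen 2048 --seed 0
+```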
+ +PPL and zero-shot accuracy: +| Models | PPL| ARCe| ARCc| HS | BQ | OQ | PQ | WGe | Avg +|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 | +| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 | +| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 | +| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 +| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 | +| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 +| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 +| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 +| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 | + +The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors. + +## Citations + +```bibtex +@article{ma2024era, + title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits}, + author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu}, + journal={arXiv preprint arXiv:2402.17764}, + year={2024} +} +``` \ No newline at end of file diff --git a/tilelang/original/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/tilelang/original/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..2604ef38770fa58fa80cf87709e0b205eae26ecd --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -0,0 +1,130 @@ +import argparse +import torch +import bitblas +from transformers.utils.hub import cached_file +import os +from transformers import GenerationConfig +import time +import json + +import sys + +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../") +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B") +parser.add_argument("--saved_model_path", type=str, default=None) +args = parser.parse_args() + +model_name_or_path = args.model_name_or_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = 
tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + model = ( + BitnetForCausalLM.from_pretrained( + model_name_or_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) + + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = model(input_ids) + print("original model output:", output) + + model.quantize(fuse_qkv=True, fuse_gateup=True) + print("original model generated text:") + print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + + model.save_pretrained(saved_model_path) + + # load quant config + quant_config_path = cached_file(model_name_or_path, "quantize_config.json") + with open(quant_config_path, "r") as f: + quant_config = json.load(f) + print("quant config:") + print(quant_config) + quant_config["checkpoint_format"] = "bitblas" + quant_config["fuse_qkv"] = True + quant_config["fuse_gateup"] = True + + # save quant config + quant_config_path = os.path.join(saved_model_path, "quantize_config.json") + with open(quant_config_path, "w") as f: + json.dump(quant_config, f) + print("quant config saved to:", quant_config_path) + + # copy benchmark filed into saved model path + file_list = [ + "configuration_bitnet.py", + "eval_utils.py", + "modeling_bitnet.py", + "tokenization_bitnet.py", + "utils_quant.py", + "README.md", + ] + for file in file_list: + file_path = cached_file(model_name_or_path, file) + os.system(f"cp {file_path} {saved_model_path}") + # load quantized model + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh b/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh new file mode 100755 index 0000000000000000000000000000000000000000..741c3a124a54bcf4206104b2034771de93a30aea --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh @@ -0,0 +1,31 @@ +# retrieve the native model input and saved model directory +MODEL_DIR=$1 +SAVED_MODEL_DIR=$2 + +# check if the model directory exists +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +# if the saved model directory does not exist, create it +# if SAVED_MODEL_DIR is not provided, we do not pass it to the script +if [ -z "$SAVED_MODEL_DIR" ]; then + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR +else + if [ ! 
-d "$SAVED_MODEL_DIR" ]; then + mkdir -p $SAVED_MODEL_DIR + fi + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR +fi + +# get the realpath of the saved model directory +SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR) + +# cp files +cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/ + +echo "Model has been converted and save to $SAVED_MODEL_DIR" diff --git a/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh b/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh new file mode 100755 index 0000000000000000000000000000000000000000..a2df0eb8cb2e057b751e572e1aa58c2532aece27 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh @@ -0,0 +1,25 @@ +# require git lfs +if ! command -v git-lfs &> /dev/null; then + echo "Please install git-lfs first by running 'sudo apt install git-lfs'" + exit 1 +fi + +mkdir -p models + +cd models + +# download the model +git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1 + +# copy quantized config into the model directory +cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B + +# copy README.md into the model directory +cp ../maint/README.md ckpt_bitnet_b1_58-3B + +# get the realpath of the model directory +MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B) + +cd .. + +echo "Model has been converted and save to $MODEL_DIR" diff --git a/tilelang/original/examples/bitnet-1.58b/maint/quantize_config.json b/tilelang/original/examples/bitnet-1.58b/maint/quantize_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e2b24123a125ebaf3c4b056e8e6546801fbac4dc --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/maint/quantize_config.json @@ -0,0 +1,10 @@ +{ + "bits": 2, + "desc_act": false, + "static_groups": false, + "sym": true, + "lm_head": false, + "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", + "quant_method": "bitnet", + "checkpoint_format": "bitnet" +} \ No newline at end of file diff --git a/tilelang/original/examples/bitnet-1.58b/maint/upload_models.sh b/tilelang/original/examples/bitnet-1.58b/maint/upload_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..b764b0da67a9b69d66a0e2a430356751de9df1e1 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/maint/upload_models.sh @@ -0,0 +1,34 @@ +MODEL_DIR=$1 +REMOTE_DIR=$2 + +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +cd $MODEL_DIR +if [ ! -d ".git" ]; then + rm -rf .git +fi + +git init + +git checkout -b main + +git lfs install + +git lfs track *.bin + +git lfs track *.safetensors + +git add . + +git commit -m "Initial commit" + +git remote add origin $REMOTE_DIR + +huggingface-cli lfs-enable-largefiles . + +git fetch origin + +git push -f --set-upstream origin main diff --git a/tilelang/original/examples/bitnet-1.58b/modeling_bitnet.py b/tilelang/original/examples/bitnet-1.58b/modeling_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1830995ee6177536089fe517646b290c18bb05f2 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/modeling_bitnet.py @@ -0,0 +1,1686 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from configuration_bitnet import BitnetConfig +from utils_quant import BitLinear, BitLinearBitBLAS +from transformers.utils.hub import cached_file + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 + + +def find_layers(module, layers=None, name=""): + if not layers: + layers = [nn.Linear] + for layer in layers: + if isinstance(module, layer): + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + "." 
+ name1 if name != "" else name1)) + return res + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BitnetConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BitnetRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm) + + +class BitnetRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. 
It is not used in the `BitnetAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class BitnetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = BitLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.up_proj = BitLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + def forward(self, x): + x = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + +class BitnetMLPFuseGateUp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = BitLinear( + self.hidden_size, + self.intermediate_size * 2, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + @classmethod + def from_bit_mlp(cls, bit_mlp: BitnetMLP): + module = cls(bit_mlp.config) + # assign the weights + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.down_proj = bit_mlp.down_proj + module.ffn_layernorm = bit_mlp.ffn_layernorm + return module + + def forward(self, x): + gate_up = self.gate_up_proj(x) + gate, up = torch.chunk(gate_up, chunks=2, dim=-1) + x = self.act_fn(gate) * up + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class BitnetAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) + + self.q_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.k_proj = BitLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.v_proj = BitLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + 
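+ # attn_weights now has shape (bsz, num_heads, q_len, kv_seq_len); masked positions hold large negative values from the additive causal mask, so they collapse to ~0 after the softmax below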
+ # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BitnetAttentionQKVFused(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." 
+ ) + + self.qkv_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + @classmethod + def from_bit_attention(cls, bit_attention: BitnetAttention): + module = cls(bit_attention.config, bit_attention.layer_idx) + # assign the weights + module.qkv_proj.weight = nn.Parameter( + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) + if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: + module.qkv_proj.bias = nn.Parameter( + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) + module.o_proj = bit_attention.o_proj + module.inner_attn_ln = bit_attention.inner_attn_ln + if bit_attention.config.rope_scaling is None: + module.rotary_emb = bit_attention.rotary_emb + return module + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split( + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BitnetFlashAttention2(BitnetAttention): + """ + Bitnet flash attention module. This module inherits from `BitnetAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (BitnetRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BitnetFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
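+ # (with left padding, the last `query_length` columns of the mask line up with the query tokens, so slicing from the right keeps mask and queries aligned)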
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +LLAMA_ATTENTION_CLASSES = { + "eager": BitnetAttention, + "flash_attention_2": BitnetFlashAttention2, +} + + +class BitnetDecoderLayer(nn.Module): + def __init__(self, config: BitnetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = BitnetMLP(config) + self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", + stacklevel=2, + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BitnetConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class BitnetPreTrainedModel(PreTrainedModel): + config_class = BitnetConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BitnetDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + for layer in self.model.layers: + device = layer.input_layernorm.weight.device + if hasattr(self.config, "_pre_quantization_dtype"): + dtype = self.config._pre_quantization_dtype + else: + dtype = layer.self_attn.o_proj.weight.dtype + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class BitnetModel(BitnetPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`BitnetDecoderLayer`] + + Args: + config: BitnetConfig + """ + + def __init__(self, config: BitnetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.") + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache and not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice + + return causal_mask + + +class BitnetForCausalLM(BitnetPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = BitnetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.quantized = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> 
Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import LlamaTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Bitnet-2-7b-hf") + >>> tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Bitnet-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + 
else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids") + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) + return reordered_past + + @staticmethod + def recursive_set(model, name, attr): + """ + set layers.25.mlp.up_proj to attr + """ + + names = name.split(".") + obj = model + for n in names[:-1]: + obj = getattr(obj, n) + setattr(obj, names[-1], attr) + + def quantize(self, fuse_qkv=True, fuse_gateup=True): + for name, module in self.model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + self.recursive_set(self.model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + self.recursive_set(self.model, name, bitnet_mlp_fused) + for name, module in self.model.named_modules(): + # if is bitnet layer + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + if name.endswith(".qkv_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3) + elif name.endswith(".gate_up_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2) + else: + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + self.recursive_set(self.model, name, bitblas_linear) + self.quantized = True + + def _post_process_weights(self): + for name, module in self.model.named_modules(): + if hasattr(module, "post_process_weights"): + print("Post processing weights for module", name) + module.post_process_weights() + + def _replace_weight_param_with_qweight(self): + for name, module in self.model.named_modules(): + if hasattr(module, "replace_weight_param_with_qweight"): + print("Replacing weight param with qweight for module", name) + module.replace_weight_param_with_qweight() + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = 
kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + # == step1: prepare configs and file names == # + config: BitnetConfig = BitnetConfig.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + **cached_file_kwargs, + ) + # load quantize config + quantize_file = cached_file(model_name_or_path, "quantize_config.json") + assert quantize_file is not None, "quantize config file not found" + import json + + # get quantize format + with open(quantize_file, "r") as f: + quant_config = json.load(f) + checkpoint_format = quant_config["checkpoint_format"] + assert checkpoint_format in ["bitblas"], "quantize format not supported" + fuse_qkv = quant_config.get("fuse_qkv", True) + fuse_gateup = quant_config.get("fuse_gateup", True) + + import accelerate + + if checkpoint_format == "bitblas": + model = cls(config) + for name, module in model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + model.recursive_set(model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + model.recursive_set(model, name, bitnet_mlp_fused) + for name, module in model.named_modules(): + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + model.recursive_set(model, name, bitblas_linear) + accelerate.utils.modeling.load_checkpoint_in_model( + model, + checkpoint=model_name_or_path, + offload_state_dict=True, + offload_buffers=True, + ) + return model + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`BitnetForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + LLAMA_START_DOCSTRING, +) +class BitnetForSequenceClassification(BitnetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = BitnetModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif 
self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Bitnet Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class BitnetForQuestionAnswering(BitnetPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Bitnet + def __init__(self, config): + super().__init__(config) + self.transformer = BitnetModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labeled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labeled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tilelang/original/examples/bitnet-1.58b/nvidia_measure_memory.sh b/tilelang/original/examples/bitnet-1.58b/nvidia_measure_memory.sh new file mode 100755 index 0000000000000000000000000000000000000000..e8998f3092bc4a7ea9d3539a7625169365133488 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/nvidia_measure_memory.sh @@ -0,0 +1 @@ +nvidia-smi --query-gpu=memory.used --format=csv -lms 500 diff --git a/tilelang/original/examples/bitnet-1.58b/requirements.txt b/tilelang/original/examples/bitnet-1.58b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..67357781e0a2afd5bd550329ec5c756f09f4b6b6 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/requirements.txt @@ -0,0 +1,3 @@ +lm_eval==0.3.0 +flash_attn +transformers==4.53.0 diff --git a/tilelang/original/examples/bitnet-1.58b/tokenization_bitnet.py b/tilelang/original/examples/bitnet-1.58b/tokenization_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2adfd6dee10e6d0fba443e14c7b828e73b378554 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/tokenization_bitnet.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for LLaMA.""" + +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class BitnetTokenizer(PreTrainedTokenizer): + """ + Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`): + The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed.
 + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add a `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up spaces after decoding; cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Bitnet should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello <extra_id_0>.") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="<unk>", + bos_token="<s>", + eos_token="</s>", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behavior of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behavior, set `legacy=False`.
This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. 
For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id + + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + + The reference for this chat template is [this code + snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) + in the original repository. + """ + logger.warning_once( + "\nNo chat template is defined for this tokenizer - using the default template " + f"for the {self.__class__.__name__} class. If the default is not appropriate for " + "your model, please set `tokenizer.chat_template` to an appropriate template. 
" + "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" + ) + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template diff --git a/tilelang/original/examples/bitnet-1.58b/utils_quant.py b/tilelang/original/examples/bitnet-1.58b/utils_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..5a50edb392ead6d55c9e34f19409cfb94848f13a --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/utils_quant.py @@ -0,0 +1,230 @@ +# pylint: disable=missing-docstring, invalid-name +"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS.""" + +import torch +from torch import nn +from bitblas.cache import global_operator_cache, get_database_path +from bitblas import Matmul, MatmulConfig +from bitblas import auto_detect_nvidia_target +from logging import getLogger + +logger = getLogger(__name__) +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = get_database_path() + + +def weight_quant(weight, num_bits=1): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + +def activation_quant(x, num_bits=8): + dtype = x.dtype + x = x.float() + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) / s + return result.type(dtype) + + +class BitLinearBitBLAS(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + weight_bits=1, + input_bits=8, + **kwargs, + ): + super().__init__() + """ + RMSNorm is placed outside BitLinear + """ + self.in_features = in_features + self.out_features = out_features + self.weight_bits = weight_bits + self.input_bits = input_bits + matmul_config = 
MatmulConfig( + N=self.out_features, # N dimension + K=self.in_features, # K dimension + A_dtype="int8", # activation A dtype + W_dtype="int2", # weight W dtype + accum_dtype="int32", # accumulation dtype + out_dtype="float32", # output dtype + layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose + with_bias=False, # bias + # configs for weight only quantization + group_size=None, # setting for grouped quantization + with_scaling=False, # setting for scaling factor + with_zeros=False, # setting for zeros + zeros_mode=None, # setting for how to calculating zeros + ) + ENABLE_TUNING = True + self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) + + self.format = "bitnet" + self.Qp = 2 ** (self.input_bits - 1) - 1 + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info(f"Loaded {global_operator_cache.size()} operators from database.") + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + # should disable tuning for the first time because we may require loading bitblas operator from database. + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + print("BitBLAS Tuning done, appended operator to global_operator_cache.") + else: + print("BitBLAS Operator created.") + else: + print("BitBLAS Operator found in global_operator_cache.") + return bitblas_matmul + + def replace_weight_param_with_qweight(self): + if hasattr(self, "weight"): + del self.weight + quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape()) + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @classmethod + def from_bit_linear(cls, bitlinear, weight_group=1): + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) + bitblas_linear.register_buffer("qweight", qweight) + bitblas_linear.register_buffer("sw", sw) + if bitlinear.bias is not None: + bitblas_linear.register_buffer("bias", bitlinear.bias) + else: + bitblas_linear.bias = None + return bitblas_linear + + def create_bitblas_weights(self, weight, weight_group=1): + if weight_group: + hidden_size = weight.size(0) + group_size = hidden_size // weight_group + + sw_list = [] + qweight_list = [] + + for i in range(weight_group): + start_idx = i * group_size + end_idx = (i + 1) * group_size + + sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5) + sw_list.append(sw.repeat(group_size)) + + qweight = self.weight_quant(weight[start_idx:end_idx]).detach() + qweight_list.append(qweight) + + sw = torch.cat(sw_list, dim=0) + qweight = torch.cat(qweight_list, dim=0) + else: + sw = 1 / weight.abs().mean().clamp(min=1e-5) + qweight = self.weight_quant(weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + qweight = nn.Parameter(qweight, requires_grad=False) + return sw, qweight + + def post_process_weights(self): + sw = 1 / self.weight.abs().mean().clamp(min=1e-5) + self.sw = sw + quant_weight = self.weight_quant(self.weight).detach() + quant_weight = 
self.bitblas_matmul.transform_weight(quant_weight) + # remove self.weight and replace it with quant_weight + if hasattr(self, "weight"): + del self.weight + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @staticmethod + def weight_quant(weight): + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) + return result.type(torch.int8) + + @torch.compile + def activation_quant(self, x, num_bits=8): + x = x.float() + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) + return result.type(torch.int8), s + + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out + + # for the correctness evaluation. + def native_forward(self, input): + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + def forward_fp32_simulated(self, input): + quant_input, si = self.activation_quant(input, self.input_bits).detach() + quant_weight = self.weight_quant(self.weight).detach() + + fp32_simulated_input = quant_input.float() + fp32_simulated_weight = quant_weight.float() + fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight) + + sw = 1 / self.weight.abs().mean().clamp(min=1e-5) + # if / (si * sw) it will inf in some cases + out = fp32_simulated_out / si + out = out / sw + out = out.half() + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + def forward(self, input): + # return self.forward_fp32_simulated(input) + quant_input, si = self.activation_quant(input, self.input_bits) + fp32_out = self.bitblas_matmul(quant_input, self.qweight) + sw = self.sw + # if / (si * sw) it will inf in some cases + out = self.post_quant_process(fp32_out, si, sw) + + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + +# Naive BitLinear from HuggingFace +class BitLinear(nn.Linear): + def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): + super(BitLinear, self).__init__(*kargs, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.weight_bits = weight_bits + self.input_bits = input_bits + + def forward(self, input): + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + + return out diff --git a/tilelang/original/examples/bitnet-1.58b/vllm_workspace/conftest.py b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e2997ef67c5c22b26235d00000332dfe20910f --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -0,0 +1,587 @@ +import contextlib +import gc +import os +import sys +from collections import UserList +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F 
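To make the ternary quantization in `utils_quant.py` above easier to follow, here is a small, self-contained sketch (illustration only, not part of the patch) of the round trip performed by `weight_quant`, `activation_quant`, and the naive `BitLinear.forward`: weights are scaled by the inverse of their mean absolute value and rounded to {-1, 0, 1}, activations are absmax-quantized per token to the int8 range, and the straight-through estimator keeps gradients flowing through the full-precision tensors. Shapes and values below are hypothetical.

```python
# Illustration only; mirrors weight_quant/activation_quant from utils_quant.py above.
import torch
import torch.nn.functional as F


def weight_quant(weight: torch.Tensor) -> torch.Tensor:
    # Scale by 1/mean(|W|), round to {-1, 0, 1}, then undo the scale (fake quantization).
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    return ((weight * s).round().clamp(-1, 1) / s).to(weight.dtype)


def activation_quant(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # Per-token absmax quantization to the signed int8 range, then dequantize.
    Qn, Qp = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return ((x * s).round().clamp(Qn, Qp) / s).to(x.dtype)


# Hypothetical sizes, chosen only for the demo.
w = torch.randn(8, 16)   # out_features x in_features
x = torch.randn(2, 16)   # batch x in_features

# Straight-through estimator, as in BitLinear.forward: quantized values are used in the
# forward pass, while gradients would flow to the full-precision tensors because the
# quantization delta is detached.
w_q = w + (weight_quant(w) - w).detach()
x_q = x + (activation_quant(x) - x).detach()
y = F.linear(x_q, w_q)
print(y.shape)  # torch.Size([2, 8])
```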
+from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + BatchEncoding, +) + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.config import TokenizerPoolConfig +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel +from vllm.inputs import TextPrompt +from vllm.logger import init_logger +from vllm.sequence import SampleLogprobs +from vllm.utils import cuda_device_count_stateless, is_cpu + +logger = init_logger(__name__) + +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] + + +def _read_prompts(filename: str) -> List[str]: + with open(filename, "r") as f: + prompts = f.readlines() + return prompts + + +class _ImageAssetPrompts(TypedDict): + stop_sign: str + cherry_blossom: str + + +if sys.version_info < (3, 9): + # UserList cannot be subscripted + class _ImageAssetsBase(UserList): + pass + +else: + + class _ImageAssetsBase(UserList[ImageAsset]): + pass + + +class _ImageAssets(_ImageAssetsBase): + def __init__(self) -> None: + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) + + def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: + """ + Convenience method to define the prompt for each test image. + + The order of the returned prompts matches the order of the + assets when iterating through this object. + """ + return [prompts["stop_sign"], prompts["cherry_blossom"]] + + +IMAGE_ASSETS = _ImageAssets() +"""Singleton instance of :class:`_ImageAssets`.""" + + +def cleanup(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. 
+ """ + + if not request.node.get_closest_marker("skip_global_cleanup"): + return False + + +@pytest.fixture(autouse=True) +def cleanup_fixture(should_do_global_cleanup_after_test: bool): + yield + if should_do_global_cleanup_after_test: + cleanup() + + +@pytest.fixture +def example_prompts() -> List[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture +def example_long_prompts() -> List[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture(scope="session") +def image_assets() -> _ImageAssets: + return IMAGE_ASSETS + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, +} + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) + + +class HfRunner: + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + is_embedding_model: bool = False, + is_vision_model: bool = False, + is_sparseml_model: bool = False, + ) -> None: + assert dtype in _STR_DTYPE_TO_TORCH_DTYPE + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + if is_embedding_model: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype) + ) + else: + if is_vision_model: + auto_cls = AutoModelForVision2Seq + elif is_sparseml_model: + from sparseml.transformers import SparseAutoModelForCausalLM + + auto_cls = SparseAutoModelForCausalLM + else: + auto_cls = AutoModelForCausalLM + + model_kwargs = model_kwargs if model_kwargs is not None else {} + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + ) + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for model %s. 
Using tokenizer instead.", + model_name, + ) + self.processor = self.tokenizer + + def generate( + self, + prompts: List[str], + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images: + assert len(prompts) == len(images) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output_ids = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + **kwargs, + ) + output_str = self.processor.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output_ids = output_ids.cpu().tolist() + outputs.append((output_ids, output_str)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + **kwargs, + ) + + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + ) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id] + outputs[i] = (output_ids, output_str) + return outputs + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[List[torch.Tensor]]: + all_logprobs: List[List[torch.Tensor]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + seq_logprobs: List[torch.Tensor] = [] + for hidden_states in output.hidden_states: + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + all_logprobs.append(seq_logprobs) + return all_logprobs + + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images 
is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + input_ids = inputs.input_ids + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +class VllmRunner: + def __init__( + self, + model_name: str, + tokenizer_name: Optional[str] = None, + # Use smaller max model length, otherwise bigger model cannot run due + # to kv cache size limit. 
+ max_model_len: int = 1024, + dtype: str = "half", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, + swap_space: int = 4, + enforce_eager: bool = False, + **kwargs, + ) -> None: + self.model = LLM( + model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + dtype=dtype, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: List[List[int]] = [] + req_sample_output_strs: List[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + + def generate_w_logprobs( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + assert sampling_params.logprobs is not None + + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str]]: + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs = self.generate(prompts, greedy_params, images=images) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) + + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> 
List[Tuple[List[List[int]], List[str]]]: + beam_search_params = SamplingParams( + n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ) + outputs = self.generate(prompts, beam_search_params) + return outputs + + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={}) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + return cuda_device_count_stateless() diff --git a/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18239cbc8fc00aaf65297a77fd5db0bf27e6ac --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -0,0 +1,45 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. 
+""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) + +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) + print("bitnet inference:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py new file mode 100644 index 0000000000000000000000000000000000000000..f631fb306772408b17d71c35a5ae8bc1084e10d9 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -0,0 +1,47 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B") + +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path + +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) + print("bitnet inference output:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/tilelang/original/examples/bitnet-1.58b/vllm_workspace/utils.py b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e96b19e28ca9e21af070bdd187e4b026aca26bc7 --- /dev/null +++ b/tilelang/original/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -0,0 +1,45 @@ +from typing import Dict, List, Tuple + +TokensText = Tuple[List[int], str] + + +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): + """ + Compare the two sequences generated by different models, + which should be equal. 
+ """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): + output_ids_0, output_str_0 = outputs_0 + output_ids_1, output_str_1 = outputs_1 + + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + + +TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] + + +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): + """ + Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + + # Break out since sequences will now diverge. + break diff --git a/tilelang/original/examples/blocksparse_attention/README.md b/tilelang/original/examples/blocksparse_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..89f75b81de950a1139c78c73616d5689afac6b49 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/README.md @@ -0,0 +1,6 @@ +# Block-Sparse Flash-Attention + +Tilelang implementation of block-sparse flash-attention kernels. + +The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). 
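+
+## Usage
+
+A minimal sketch (not a stable API) of how the TileLang kernel from `example_tilelang_block_sparse_attn.py` in this directory can be built and called; shapes, dtypes, and the block-level top-k mask helper follow that example:
+
+```python
+import torch
+from example_tilelang_block_sparse_attn import blocksparse_flashattn, get_sparse_attn_mask_from_topk
+
+B, H, S, D, BLOCK = 1, 1, 256, 64, 64  # the kernel tiles the sequence into 64x64 blocks
+q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
+k = torch.randn_like(q)
+v = torch.randn_like(q)
+
+# Block-level importance scores -> boolean [B, H, S/BLOCK, S/BLOCK] mask (top-2 blocks per row, causal).
+num_blocks = S // BLOCK
+scores = torch.randn(B, H, num_blocks, num_blocks, device="cuda")
+scores[..., 0] = 100.0  # always keep the first block so every query block attends to something
+block_mask = get_sparse_attn_mask_from_topk(scores, topk=2)
+
+kernel = blocksparse_flashattn(B, H, S, D, num_blocks, is_causal=True)
+out = kernel(q, k, v, block_mask)  # [B, H, S, D], float16
+```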
+ diff --git a/tilelang/original/examples/blocksparse_attention/block_sparse_attn_triton.py b/tilelang/original/examples/blocksparse_attention/block_sparse_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..1794836342197de8c16bfa2eb515e872c94c663b --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -0,0 +1,361 @@ +# ruff: noqa: E712 +import math +import torch + +import triton +import triton.language as tl +import torch.nn.functional as F + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + # print + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, 
:] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. 
+ raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + print("downsample_len", downsample_len) + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + print("x_ds.shape", x_ds.shape) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + + # Run Triton kernel + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # print("ref_output", ref_output) + # print("triton_output", triton_output) + + # Verify accuracy + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlt_kl(): + BATCH, N_HEADS = 2, 4 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + # softmax scale + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + print("downsample_factor", downsample_factor) + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + print("downsample_len", downsample_len) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + # Run Triton kernel. 
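+    # With Q_LEN < K_LEN the kernel derives past_len = K_LEN - Q_LEN (128 here) from the input
+    # shapes and shifts the causal mask by that offset (see PAST_LEN in `_forward` / `_fwd_kernel`).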
+ triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # Verify accuracy. + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" + + print("Pass topk sparse attention test with qlen < klen") + + +def main(): + test_topk_sparse_attention() + test_topk_sparse_attention_qlt_kl() + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/tilelang/original/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..934b0b25efaac9568dac2b398d274321a803b54a --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -0,0 +1,221 @@ +import math +import torch + +import tilelang +import tilelang.language as T +import torch.nn.functional as F + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 1 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, 
K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def blocksparse_flashattn( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx 
+ 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return blocksparse_flashattn + + return kernel_func(block_M, block_N, num_stages, threads) + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + # Run tilelang kernel + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + tilelang_output = kernel(q, k, v, block_mask) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy + torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2) + print("Pass topk sparse attention test with qlen == klen") + + +def main(): + test_topk_sparse_attention() + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..77a29ebe284ef7df8265687bca1217166475739d --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -0,0 +1,551 @@ +# ruff: noqa +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import time +import math + +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func( + 
block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + ): + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): 
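+                        # `scale` already folds in log2(e), so exp2(x * scale - max * scale)
+                        # equals exp((x - max) / sqrt(dim)); exp2 maps to cheaper GPU instructions than exp.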
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + max_split = T.alloc_local([1], T.int32) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split[0] = glse[bz, by, k] + if lse_local_split[0] != 0: + max_split[0] = k + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split[0]: + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + if k <= max_split[0]: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func + + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_N = block_N + self.page_block_size = page_block_size + self.num_pages = num_pages + self.block_H = 
64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_N, + block_H=self.block_H, + page_block_size=page_block_size, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + num_pages=num_pages, + max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_indices, cache_seqlens, block_table): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_N + max_selected_blocks = block_indices.shape[-1] + + # Compute static scheduling parameters + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = self.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = self.kernel( + query, + key, + value, + block_indices, + cache_seqlens, + block_table, + glse, + output_partial, + ) + return output + + +def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): + """ + Paged version of sparse attention reference implementation. + + Args: + query: [batch, heads, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] + value_cache: [num_pages, page_block_size, heads_kv, dim] + block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices + cache_seqlens: [batch] - actual sequence lengths + block_table: [batch, max_num_blocks_per_seq] - maps logical to physical blocks + page_block_size: size of each page block + block_size: size of attention blocks (block_N) + """ + batch, heads, dim = query.shape + heads_kv = key_cache.shape[2] + dim_v = value_cache.shape[3] + num_head_groups = heads // heads_kv + scale = dim**0.5 + + # Reconstruct the full key and value tensors from paged cache + max_cache_seqlen = max(cache_seqlens).item() + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) + + # Reconstruct full tensors from paged cache using block_table + for b in range(batch): + seq_len = cache_seqlens[b].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + for block_idx in range(num_blocks_needed): + physical_block_idx = block_table[b, block_idx].item() + + # Calculate the range of tokens for this block + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + + # Copy from paged cache to full tensors + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + + # Reshape query for grouped attention + query = rearrange(query, "b (h 
g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + # Compute attention scores + scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + # Create sparse mask based on block_indices + sparse_mask = torch.zeros_like(scores) + + # Apply sparse mask based on selected blocks + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: # Valid block index + start_pos = idx * block_size + end_pos = min(start_pos + block_size, max_cache_seqlen) + sparse_mask[b, :, h, start_pos:end_pos] = 1 + + # Apply sparse mask + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + # Apply causal mask based on actual sequence lengths + range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + + # Compute attention weights + attention = F.softmax(scores / scale, dim=-1) + + # Apply attention to values + out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + + # Reshape output back to original format + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + + return out + + +def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = output.squeeze(1) + return output + + +def main(args): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_blocks = args.num_pages # Use num_pages from args + + # For dense case verification, set sparse_ratio to 0 to select all blocks + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + # Generate random inputs + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") + print("cache_seqlens: ", cache_seqlens) + + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + + # Create paged KV cache + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + + # Create block table and block indices for dense case (all blocks selected) + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") + + # Fill block 
table and block indices and cache + + # Create a pool of available physical blocks + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + + random.seed(42) # For reproducibility + random.shuffle(available_blocks) + + # Fill block table with random physical block indices + block_assignment = {} # Map (seq_idx, block_idx) -> physical_block_idx + block_idx_counter = 0 + + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + # Assign random physical blocks for each sequence + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, block_idx)] = physical_block_idx + block_idx_counter += 1 + + print(f"Block table: {block_table}") + + # Fill K_cache and V_cache with data from original K and V tensors using random block assignment + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + + # Calculate the range of tokens for this block + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + + # Copy K and V data to the paged cache + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] + + # Fill block_indices for sparse attention + # For dense case (verification), we select all blocks in reverse order + # For sparse case, we select a subset of blocks based on sparse_ratio + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + + if sparse_ratio == 0.0: + # Dense case: select all blocks in reverse order + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + # Select blocks in reverse order (most recent first) + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + # Fill remaining slots with -1 (invalid) + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + # Fill block_indices for all KV heads + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + # Always include the most recent blocks + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + + # Randomly select some earlier blocks + if num_selected > recent_blocks: + remaining_blocks = [b for b in all_blocks if b not in selected_blocks] + if remaining_blocks: + import random + + random.seed(42) # For reproducibility + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + # Sort selected blocks in reverse order (most recent first) + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + # Fill remaining slots with -1 (invalid) + for i in 
range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + # Initialize sparse attention module + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + + import flash_attn # noqa: F401 + + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) + + output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + # Check correctness + if sparse_ratio == 0.0: + max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() + mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + else: + max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() + mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() + + print(f"Max difference: {max_diff:.6f}") + print(f"Mean difference: {mean_diff:.6f}") + + if max_diff < 1e-2: + print("✓ Verification PASSED: Results match within tolerance") + else: + print("✗ Verification FAILED: Results differ significantly") + + # Performance measurement + for _ in range(10): # Warm-up + sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + + torch.cuda.synchronize() + start_time = time.time() + for _ in range(100): # Run multiple times for averaging + sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + torch.cuda.synchronize() + end_time = time.time() + + kernel_time = (end_time - start_time) / 100 * 1000 # Convert to ms + print(f"Kernel execution time: {kernel_time:.2f} ms") + + # FA performance measurement + for _ in range(10): # Warm-up + ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + + torch.cuda.synchronize() + start_time_fa = time.time() + for _ in range(100): # Run multiple times for averaging + ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + torch.cuda.synchronize() + end_time_fa = time.time() + kernel_time_fa = (end_time_fa - start_time_fa) / 100 * 1000 # Convert to ms + print(f"FA kernel execution time: {kernel_time_fa:.2f} ms") + + print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages") + args = parser.parse_args() + main(args) diff --git a/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py new file mode 100644 index 
0000000000000000000000000000000000000000..257f41543c3fc2f9d4e044d4ef9a4283edf01142 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -0,0 +1,435 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import time +import math +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + 
T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + max_split = T.alloc_local([1], T.int32) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split[0] = glse[bz, by, k] + if lse_local_split[0] != 0: + max_split[0] = k + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split[0]: + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + if k <= max_split[0]: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) + flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func 
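+
+
+# Note: `flashattn(...)` returns a kernel factory; instantiating it (as `SparseFlashAttn` does
+# below) fixes block_N / block_H and marks num_split, max_cache_seqlen and max_selected_blocks
+# as dynamic shapes, so the compiled kernel is shape-generic in those dimensions.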
+ + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_size = block_size + + self.block_H = 64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=self.block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_indices, cache_seqlens): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_size + max_selected_blocks = block_indices.shape[-1] + + # Compute static scheduling parameters + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = self.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + return output + + +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): + """ + Args: + query: [batch, heads, dim] + key: [batch, max_cache_seqlen, heads_kv, dim] + value: [batch, max_cache_seqlen, heads_kv, dim_v] + block_indices: [batch, heads_kv, max_selected_blocks], indices of selected blocks, -1 for padding + cache_seqlens: [batch], sequence lengths of the kvcache + max_cache_seqlen: maximum sequence length of kvcache + block_size: block size + Returns: + output: [batch, heads, dim_v] + + """ + + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + max_selected_blocks = block_indices.shape[-1] + block_H = 64 + + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + + # get num_split + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size + # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 132 + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=block_H, + 
num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) + return output + + +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values based on block_indices + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def debug(name, expect, actual, atol=1e-3, rtol=1e-3): + all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) + print(name + " all_close={}".format(all_close)) + if not all_close: + diff = (expect - actual).abs() + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) + max_indices = torch.nonzero(diff == diff.max().item()) + first_index = tuple(max_indices[0].tolist()) + print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") + + +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, 
dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + # # Ensure at least one element equals cache_seqlen + # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index + # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + print("cache_seqlens: ", cache_seqlens) + + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_indices with -1 (for padding blocks) + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) + # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + if max_valid_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + # Sort indices within each batch-group for consistency + block_indices, _ = block_indices.sort(dim=-1, descending=True) + # print("block_indices: ", block_indices) + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] + print("actual_num_blocks: ", actual_num_blocks) + # print(block_indices.shape, actual_num_blocks.shape) + + max_num_blocks = torch.max(max_valid_num_blocks).item() + print("max_num_blocks: ", max_num_blocks) + + # parity reference + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + debug("output", ref, out, atol=1e-3, rtol=1e-3) + + import flash_attn # noqa: F401 + + ## latency reference + for _ in range(10): + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + torch.cuda.synchronize() + print("dense time: ", (time.time() - start) / 100 * 1000) + + for _ in range(10): + # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + torch.cuda.synchronize() + print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + 
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2957f8c970986e4f5f48673a7026677a10dc2b17 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -0,0 +1,420 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + +import time +import math +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = 
T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, 
T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func + + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_size = block_size + + self.block_H = 64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=self.block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + num_blocks=T.dynamic("num_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_mask, cache_seqlens): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_size + block_H = self.block_H + max_cache_seqlen = key.shape[1] + # get num_split + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + # num_sm = 132 + num_sm = self.num_sm + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + # print("num_split: ", num_split) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + return output + + +def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size): + """ + Args: + query: [batch, heads, dim] + key: [batch, max_cache_seqlen, heads_kv, dim] + value: [batch, max_cache_seqlen, heads_kv, dim_v] + block_mask: [batch, heads_kv, num_blocks], mask for valid blocks + cache_seqlens: [batch], sequence lengths of the kvcache + block_size: block size + Returns: + output: [batch, heads, dim_v] + + """ + + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + block_H = 64 + + actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + max_selected_blocks = actual_num_blocks.max().item() + # get num_split + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size + # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 132 + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, 
size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + num_blocks=T.dynamic("num_blocks"), + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + # print(kernel.get_kernel_source()) + + output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + + return output + + +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values + for b in range(batch): + for h in range(heads_kv): + for idx in range(num_blocks): + if block_mask[b, h, idx]: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def debug(name, expect, actual, atol=1e-3, rtol=1e-3): + all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) + print(name + " all_close={}".format(all_close)) + if not all_close: + # print(expect[3, 28]) + # print(actual[3, 28]) + diff = (expect - actual).abs() + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) + max_indices = torch.nonzero(diff == diff.max().item()) + first_index = tuple(max_indices[0].tolist()) + print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") + + +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = 
int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + + print("cache_seqlens: ", cache_seqlens) + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + print("valid_num_blocks: ", valid_num_blocks) + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_mask with false (for padding blocks) + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch + if valid_num_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + # print("block_mask: ", block_mask) + + # parity reference + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + debug("output", ref, out, atol=1e-3, rtol=1e-3) + + import flash_attn # noqa: F401 + + ## latency reference + for _ in range(10): + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + torch.cuda.synchronize() + print("dense time: ", (time.time() - start) / 100 * 1000) + + for _ in range(10): + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + torch.cuda.synchronize() + print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + 
parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py new file mode 100644 index 0000000000000000000000000000000000000000..b61d52fa092f4d8cd115905d71cde59a99ca88dc --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -0,0 +1,433 @@ +# ruff: noqa +import torch +import triton +import triton.language as tl +import argparse +from einops import rearrange, einsum +import torch.nn.functional as F + +import math +import time +from heuristic import num_splits_heuristic + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], +) +@triton.jit +def _split_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + cache_seqlens_ptr, + o_partial_ptr, + lse_partial_ptr, + mask_ptr, + sm_scale, + num_splits, + gqa_group_size, + max_selected_blocks, + stride_q_b, + stride_q_h, + stride_q_d, + stride_k_b, + stride_k_s, + stride_k_h, + stride_k_d, + stride_v_b, + stride_v_s, + stride_v_h, + stride_v_d, + stride_o_b, + stride_o_h, + stride_o_split, + stride_o_d, + stride_lse_b, + stride_lse_h, + stride_lse_split, + stride_mask_b, + stride_mask_h, + stride_mask_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx_kv = tl.program_id(1) + split_idx = tl.program_id(2) + + head_idx_q = head_idx_kv * gqa_group_size + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx) + num_blocks = max_selected_blocks + blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32) + remaining_blocks = num_blocks % num_splits + if split_idx < remaining_blocks: + loop_range = blocks_per_split + 1 + else: + loop_range = blocks_per_split + + q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d + mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h + + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) + start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) + for i in range(loop_range): + block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) + if block_idx >= 0: + start_n = block_idx * BLOCK_N + k_ptr = k_cache_ptr + start_n * stride_k_s + v_ptr = v_cache_ptr + start_n * stride_v_s + + k = tl.load(k_ptr, mask=start_n + 
offs_n[None, :] < cache_seqlens, other=0.0) + v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0) + + qk = tl.dot(q, k) + qk = qk * sm_scale + qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + m_i = m_ij + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(o_partial_ptr.dtype.element_ty) + + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) + + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) + tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], +) +@triton.jit +def _merge_kernel( + o_partial_ptr, + lse_partial_ptr, + o_ptr, + lse_partial_stride_b, + lse_partial_stride_h, + lse_partial_stride_split, + o_partial_stride_b, + o_partial_stride_h, + o_partial_stride_split, + o_partial_stride_d, + o_stride_b, + o_stride_h, + o_stride_d, + BLOCK_D: tl.constexpr, + num_splits: tl.constexpr, + num_splits_pow2: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + offs_splits = tl.arange(0, num_splits_pow2) + offs_d = tl.arange(0, BLOCK_D) + + lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) + + lse_max = tl.max(lse) + + o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h + o_partial = tl.load( + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) + sumexp_normalized_splitk = tl.exp(lse - lse_max) + sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) + numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) + acc = numerator_normalized / sumexp_normalized + acc = acc.to(o_ptr.dtype.element_ty) + o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h + tl.store(o_ptr + offs_d * o_stride_d, acc) + + +def block_sparse_flash_decode_gqa_indice_triton( + q, + k_cache, + v_cache, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + sm_scale=None, +): + batch, heads, dim = q.shape + + if sm_scale is None: + sm_scale = 1 / math.sqrt(dim) + + _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape + assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch" + group_size = heads // heads_kv + + block_H = 16 + + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 64 + # num_sm = self.num_sm + num_splits = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, 
max_splits=128 + ) + + # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) + + num_splits_pow2 = triton.next_power_of_2(num_splits) + + o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype) + lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32) + + BLOCK_D = dim + BLOCK_H = group_size if group_size > 16 else 16 + grid = (batch, heads_kv, num_splits) + _split_kernel[grid]( + q, + k_cache, + v_cache, + cache_seqlens, + o_partial, + lse_partial, + block_indices, + sm_scale, + num_splits, + group_size, + max_selected_blocks, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + block_indices.stride(0), + block_indices.stride(1), + block_indices.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=block_size, + BLOCK_D=BLOCK_D, + ) + + output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype) + grid = (batch, heads) + _merge_kernel[grid]( + o_partial, + lse_partial, + output, + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_D=dim_v, + num_splits=num_splits, + num_splits_pow2=num_splits_pow2, + ) + + return output + + +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values based on block_indices + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, cache_seqlens): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = 
flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + qk_flops = 2 * batch * heads * max_cache_seqlen * dim + pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v + total_flops = qk_flops + pv_flops + + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + block_H = 64 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + print("cache_seqlens: ", cache_seqlens) + + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_indices with -1 (for padding blocks) + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + if max_valid_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + # Sort indices within each batch-group for consistency + block_indices, _ = block_indices.sort(dim=-1, descending=True) + # print("block_indices: ", block_indices) + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] + print("actual_num_blocks: ", actual_num_blocks) + # print(block_indices.shape, actual_num_blocks.shape) + + max_num_blocks = torch.max(max_valid_num_blocks).item() + print("max_num_blocks: ", max_num_blocks) + + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + + triton_out = block_sparse_flash_decode_gqa_indice_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + ) + + print("max difference: ", torch.max(torch.abs(ref - triton_out))) + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + print("Passed the ref test!") + + # Measure performance + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + block_sparse_flash_decode_gqa_indice_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + ) + torch.cuda.synchronize() + end = time.time() + elapsed_time = end - start + avg_time = elapsed_time / 1000 + 
avg_flops = total_flops / avg_time + print(f"Average time: {avg_time:.6f} seconds") + + # Measure performance of reference implementation + import flash_attn # noqa: F401 + + start = time.time() + for _ in range(1000): + ref_program_fa(Q, K, V, cache_seqlens) + torch.cuda.synchronize() + end = time.time() + elapsed_time_ref = end - start + avg_time_ref = elapsed_time_ref / 1000 + avg_flops_ref = total_flops / avg_time_ref + print(f"Average time of ref: {avg_time_ref:.6f} seconds") + + print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..c05b3777952fddc834cc46377a823a4c14e0e999 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -0,0 +1,419 @@ +import torch +import triton +import triton.language as tl +import argparse +from einops import rearrange, einsum +import torch.nn.functional as F + +import math +import time +from heuristic import num_splits_heuristic + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], +) +@triton.jit +def _split_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + cache_seqlens_ptr, + o_partial_ptr, + lse_partial_ptr, + mask_ptr, + sm_scale, + num_splits, + gqa_group_size, + stride_q_b, + stride_q_h, + stride_q_d, + stride_k_b, + stride_k_s, + stride_k_h, + stride_k_d, + stride_v_b, + stride_v_s, + stride_v_h, + stride_v_d, + stride_o_b, + stride_o_h, + stride_o_split, + stride_o_d, + stride_lse_b, + stride_lse_h, + stride_lse_split, + stride_mask_b, + stride_mask_h, + stride_mask_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx_kv = tl.program_id(1) + split_idx = tl.program_id(2) + + head_idx_q = head_idx_kv * gqa_group_size + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx) + num_blocks = (cache_seqlens + BLOCK_N - 1) // BLOCK_N + blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32) + remaining_blocks = num_blocks % num_splits + if split_idx < remaining_blocks: + loop_range = blocks_per_split + 1 + else: + 
loop_range = blocks_per_split + + q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d + mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h + + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) + start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) + for block_idx in range(loop_range): + start_n = (start + block_idx) * BLOCK_N + mask_val = tl.load(mask_ptr + (start + block_idx) * stride_mask_s) + if mask_val == 1: + k_ptr = k_cache_ptr + start_n * stride_k_s + v_ptr = v_cache_ptr + start_n * stride_v_s + + k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0) + v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0) + + qk = tl.dot(q, k) + qk = qk * sm_scale + qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + m_i = m_ij + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(o_partial_ptr.dtype.element_ty) + + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) + + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) + tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], +) +@triton.jit +def _merge_kernel( + o_partial_ptr, + lse_partial_ptr, + o_ptr, + lse_partial_stride_b, + lse_partial_stride_h, + lse_partial_stride_split, + o_partial_stride_b, + o_partial_stride_h, + o_partial_stride_split, + o_partial_stride_d, + o_stride_b, + o_stride_h, + o_stride_d, + BLOCK_D: tl.constexpr, + num_splits: tl.constexpr, + num_splits_pow2: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + offs_splits = tl.arange(0, num_splits_pow2) + offs_d = tl.arange(0, BLOCK_D) + + lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) + + lse_max = tl.max(lse) + + o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h + o_partial = tl.load( + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) + sumexp_normalized_splitk = tl.exp(lse - lse_max) + sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) + numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) + acc = numerator_normalized / sumexp_normalized + acc = acc.to(o_ptr.dtype.element_ty) + o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h + tl.store(o_ptr + offs_d * 
o_stride_d, acc) + + +def block_sparse_flash_decode_gqa_mask_triton( + q, + k_cache, + v_cache, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + sm_scale=None, +): + batch, heads, dim = q.shape + + if sm_scale is None: + sm_scale = 1 / math.sqrt(dim) + + _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape + assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch" + group_size = heads // heads_kv + + block_H = 16 + + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 64 + # num_sm = self.num_sm + num_splits = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) + + num_splits_pow2 = triton.next_power_of_2(num_splits) + + o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype) + lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32) + + BLOCK_D = dim + BLOCK_H = group_size if group_size > 16 else 16 + grid = (batch, heads_kv, num_splits) + _split_kernel[grid]( + q, + k_cache, + v_cache, + cache_seqlens, + o_partial, + lse_partial, + block_mask, + sm_scale, + num_splits, + group_size, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + block_mask.stride(0), + block_mask.stride(1), + block_mask.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=block_size, + BLOCK_D=BLOCK_D, + ) + + output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype) + grid = (batch, heads) + _merge_kernel[grid]( + o_partial, + lse_partial, + output, + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_D=dim_v, + num_splits=num_splits, + num_splits_pow2=num_splits_pow2, + ) + + return output + + +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values + for b in range(batch): + for h in range(heads_kv): + for idx in range(num_blocks): + if block_mask[b, h, idx]: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len 
= torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, cache_seqlens): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + block_size = block_size + sparse_ratio = sparse_ratio + qk_flops = 2 * batch * heads * max_cache_seqlen * dim + pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v + total_flops = qk_flops + pv_flops + + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + print("valid_num_blocks: ", valid_num_blocks) + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_mask with false (for padding blocks) + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch + if valid_num_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + + triton_out = block_sparse_flash_decode_gqa_mask_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + ) + + # print("max difference: ", torch.max(torch.abs(ref - triton_out))) + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + print("Passed the ref test!") + + # Measure performance + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + block_sparse_flash_decode_gqa_mask_triton( + Q, + K, + 
V, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + ) + torch.cuda.synchronize() + end = time.time() + elapsed_time = end - start + avg_time = elapsed_time / 1000 + avg_flops = total_flops / avg_time + print(f"Average time: {avg_time:.6f} seconds") + print(f"Average flops: {avg_flops:.2f} GFLOPS") + + import flash_attn # noqa: F401 + + start = time.time() + for _ in range(1000): + ref_program_fa(Q, K, V, cache_seqlens) + + torch.cuda.synchronize() + end = time.time() + elapsed_time_ref = end - start + avg_time_ref = elapsed_time_ref / 1000 + avg_flops_ref = total_flops / avg_time_ref + print(f"Average time of ref: {avg_time_ref:.6f} seconds") + print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS") + + print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/tilelang/original/examples/blocksparse_attention/heuristic.py b/tilelang/original/examples/blocksparse_attention/heuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6fc528196e3f111924b7d16b34d0c9af8c3800 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/heuristic.py @@ -0,0 +1,54 @@ +import math + + +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): + """ + Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. + + Parameters: + - total_mblocks (int): Total number of m_blocks. + - num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU. + - num_n_blocks (int): Number of n_blocks. + - num_m_blocks (int): Number of m_blocks. + - size_one_kv_head (int): Size of one KV head in bytes. + - is_causal_or_local (bool): Indicates whether the operation is causal or local. + - max_splits (int): Maximum number of allowed splits. + + Returns: + - int: The optimal number of splits. + """ + # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply. 
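+    # In that regime, split only when a single KV head is larger than the assumed 50MB L2
+    # working set, there are at least 2 * num_SMs m_blocks, and the attention is neither
+    # causal nor local; otherwise a single split avoids the extra combine pass.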
+ if total_mblocks >= 0.8 * num_SMs: + size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB) + # Only split if each KV head is too large for L2 and there are enough m_blocks + if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local: + return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits) + else: + return 1 + + # If num_n_blocks is too small, we don't split + if num_n_blocks <= 4: + return 1 + + # Limit max_splits to a reasonable range + max_splits = min(max_splits, num_SMs, num_n_blocks) + + max_efficiency = 0.0 + efficiency = [] + + # Compute efficiency for different splits + for num_splits in range(1, max_splits + 1): + n_waves = (total_mblocks * num_splits) / num_SMs + eff = n_waves / math.ceil(n_waves) + # Track max efficiency + if eff > max_efficiency: + max_efficiency = eff + + efficiency.append(eff) + + # Find the smallest number of splits that achieves at least 85% of max efficiency + for num_splits in range(1, max_splits + 1): + if efficiency[num_splits - 1] >= 0.85 * max_efficiency: + return num_splits + + return 1 diff --git a/tilelang/original/examples/blocksparse_attention/test_example_blocksparse_attention.py b/tilelang/original/examples/blocksparse_attention/test_example_blocksparse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..dd33f46c4ef9705350bc2cc8894cb715d4444346 --- /dev/null +++ b/tilelang/original/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -0,0 +1,39 @@ +import tilelang.testing +import block_sparse_attn_triton +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask +import example_triton_sparse_gqa_decode_varlen_indice +import example_triton_sparse_gqa_decode_varlen_mask + + +def test_block_sparse_attn_triton(): + block_sparse_attn_triton.main() + + +def test_example_tilelang_block_sparse_attn(): + example_tilelang_block_sparse_attn.main() + + +def test_example_tilelang_sparse_gqa_decode_varlen_indice(): + example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048) + + +def test_example_tilelang_sparse_gqa_decode_varlen_mask(): + example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048) + + +def test_example_triton_sparse_gqa_decode_varlen_indice(): + example_triton_sparse_gqa_decode_varlen_indice.main( + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) + + +def test_example_triton_sparse_gqa_decode_varlen_mask(): + example_triton_sparse_gqa_decode_varlen_mask.main( + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/blocksparse_gemm/example_blocksparse_gemm.py b/tilelang/original/examples/blocksparse_gemm/example_blocksparse_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a34e45de7c31f4594f76127cea577306c7554e --- /dev/null +++ b/tilelang/original/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -0,0 +1,179 @@ +import argparse +import itertools +import tilelang +import tilelang.language as T +from tilelang.engine.param import KernelParam +from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType +import torch +from typing import List + +DEFAULT_BLOCK_M = 128 +DEFAULT_BLOCK_N = 128 +DEFAULT_BLOCK_K = 32 +DEFAULT_NUM_STAGES = 2 
+DEFAULT_THREAD_NUM = 128 +DEFAULT_ENABLE_RASTERIZATION = True + +parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") +parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") +parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") +parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") +parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") +parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") + +args, _ = parser.parse_known_args() +M, N, K = args.m, args.n, args.k +sparsity = args.sparsity +use_autotune = args.use_autotune +default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto) + +print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") +print(f"Target Block Sparsity: {sparsity}") +print(f"Using Autotuner: {use_autotune}\n") + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] + + +def ref_program(A, B, BlockMask, block_M, block_N, block_K): + ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + for i in range(M // block_M): + for j in range(N // block_N): + accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) + for k in range(K // block_K): + if BlockMask[i, j, k]: + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) + return ref_c + + +def supply_program(params: List[KernelParam]): + input_tensors = [] + + for p in params: + # Check if the kernel parameter is BlockMask tensor. + # Here, BlockMask is uniquely identified by having 3 dimensions. + if len(p.shape) != 3: + # For non-BlockMask tensors, use the default tensor generation logic. + input_tensors.append(default_tensor_supply(p)) + else: + # For BlockMask tensor, randomly set elements to True based on desired + # sparsity level. 
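+            # Each block entry is kept active with probability (1 - sparsity), so the
+            # expected fraction of True blocks matches the requested density.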
+ block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device()) + block_mask[:, :, :] = torch.rand(p.shape) > sparsity + input_tensors.append(block_mask) + + return input_tensors + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): + block_mask_shape = (M // block_M, N // block_N, K // block_K) + + @T.prim_func + def block_sparse_matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if BlockMask[by, bx, k]: + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return block_sparse_matmul + + +def main(): + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + if args.use_autotune: + # Run the autotuner to find the best kernel configuration and performance + # get_best_config is expected to return an object containing the compiled kernel, + # the best configuration found, latency, and reference latency. + kernel = blocksparse_matmul(M, N, K) + + best_config = kernel.config + best_latency = kernel.latency + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] + + print(f"Best Config: {best_config}") + print(f"Sparsity Ratio: {sparsity}") + print(f"Best Kernel Latency: {best_latency:.6f} ms") + else: + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + try: + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("✅ Results are close! 
Verification successful.") + except AssertionError as e: + print("❌ Verification FAILED: Results differ significantly.") + print(e) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/blocksparse_gemm/test_example_blocksparse_gemm.py b/tilelang/original/examples/blocksparse_gemm/test_example_blocksparse_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b39f5e882b454392d9d7a380b923320e9cbbea --- /dev/null +++ b/tilelang/original/examples/blocksparse_gemm/test_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_blocksparse_gemm + + +def test_example_blocksparse_gemm(): + example_blocksparse_gemm.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/cast/example_group_per_split_token_cast_to_fp8.py b/tilelang/original/examples/cast/example_group_per_split_token_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..6bde50c512ad038424e99235edf7ca44abb2d853 --- /dev/null +++ b/tilelang/original/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -0,0 +1,209 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple +from tilelang.utils.tensor import torch_assert_close + +# support bfloat16, float, float16 +dtype = T.bfloat16 +accum_dtype = T.float32 + + +@tilelang.jit(out_idx=[2, 3]) +def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): + group_size = 128 + fp8_min = -448.0 + fp8_max = 448.0 + + @T.prim_func + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), T.int32), + X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + row = bx + row_g_id = by + bg = bz + y_local = T.alloc_fragment((blk_m, group_size), accum_dtype) + y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) + y_s_local = T.alloc_fragment((blk_m,), accum_dtype) + y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + row_offset = T.alloc_fragment((1,), T.int32) + + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) + + row_offset[0] = 0 + for i in T.serial(bg): + row_offset[0] += batch_sizes[i] + + T.copy( + X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) + T.reduce_absmax(y_local, y_amax_local, dim=1) + for i in T.Parallel(blk_m): + y_amax_local[i] = T.max(y_amax_local[i], 1e-4) + y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) + for i, j in T.Parallel(blk_m, group_size): + y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) + T.copy(y_q_local, y_q_local_fp8) + for i, j in T.Parallel(blk_m, group_size): + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) + for i in T.Parallel(blk_m): + X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) + + return group_per_split_token_cast + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. 
+ + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along the M axis + (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. + """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # this function don't support cpu tensor + assert x.dim() == 2 + m, n = x.shape + new_n = ceil_div(n, 128) * 128 + x_padded = torch.nn.functional.pad(x, (0, new_n - n)) + x_view = x_padded.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() + return x_fp8, (x_amax / 448.0).view(m, -1) + + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # assert x.shape[0] == batch_sizes.sum() + M_max = ceil_div(batch_sizes.max(), 128) * 128 + split_x = torch.split(x, batch_sizes.tolist(), dim=0) + padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] + num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8 + + +def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == T.float: + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == 
T.float16: + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == T.bfloat16: + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + print("batch_sizes:", batch_sizes) + print("M_max:", M_max) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + print(kernel.get_kernel_source()) + # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + print("All checks pass.") + + from tilelang.profiler import do_bench + + def run_tilelang(): + x_fp8_tilelang_, x_amax_tilelang_ = kernel(x, batch_sizes) + return x_fp8_tilelang_, x_amax_tilelang_ + + def run_torch(): + x_fp8_torch_, x_amax_torch_ = ref_program(x, batch_sizes) + return x_fp8_torch_, x_amax_torch_ + + latency = do_bench(run_tilelang) + print("Tile-lang: {:.2f} ms".format(latency)) + + latency = do_bench(run_torch) + print("Torch: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/cast/example_per_token_cast_to_fp8.py b/tilelang/original/examples/cast/example_per_token_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6d14884039d228e2d2082662e952571622db20 --- /dev/null +++ b/tilelang/original/examples/cast/example_per_token_cast_to_fp8.py @@ -0,0 +1,113 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple +from tilelang.utils.tensor import torch_assert_close + + +@tilelang.jit(out_idx=[1, 2]) +def per_token_cast_to_fp8(M, N, blk_m): + dtype = T.float + group_size = 128 + fp8_min = -448.0 + fp8_max = 448.0 + + @T.prim_func + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): + row = bx + row_g_id = by + y_local = T.alloc_fragment((blk_m, group_size), dtype) + y_amax_local = T.alloc_fragment((blk_m,), dtype) + y_s_local = T.alloc_fragment((blk_m,), dtype) + y_q_local = T.alloc_fragment((blk_m, group_size), dtype) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) + + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) + T.reduce_absmax(y_local, y_amax_local, dim=1) + for i in T.Parallel(blk_m): + y_amax_local[i] = T.max(y_amax_local[i], 1e-4) + y_s_local[i] = y_amax_local[i] / fp8_max + for i, j in T.Parallel(blk_m, group_size): + y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) + T.copy(y_q_local, y_q_local_fp8) + for i in T.Parallel(blk_m): + X_amax[row * blk_m + i, row_g_id] = y_s_local[i] + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) + + return per_token_cast + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. 
+ + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # this function don't support cpu tensor + assert x.dim() == 2 + m, n = x.shape + new_n = ceil_div(n, 128) * 128 + x_padded = torch.nn.functional.pad(x, (0, new_n - n)) + x_view = x_padded.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() + return x_fp8, (x_amax / 448.0).view(m, -1) + + +def main(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + + x_fp8, x_amax = kernel(x) + x_fp8_ref, x_amax_ref = ref_program(x) + + print("x_fp8:", x_fp8, x_fp8.shape) + print("x_amax:", x_amax, x_amax.shape) + print("x_fp8_ref:", x_fp8_ref, x_fp8_ref.shape) + print("x_amax_ref:", x_amax_ref, x_amax_ref.shape) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + latency = profiler.do_bench() + print("Tile-lang: {:.2f} ms".format(latency)) + + from tilelang.profiler import do_bench + from example_triton_cast_to_fp8 import per_token_group_quant_fp8 + + def run_triton(): + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + return x_fp8_triton_, x_amax_triton_ + + x_fp8_triton, x_amax_triton = run_triton() + latency = do_bench(run_triton) + print("Triton: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/cast/example_triton_cast_to_fp8.py b/tilelang/original/examples/cast/example_triton_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..1859433f10b6f6bd438846473b5661718c34fe4f --- /dev/null +++ b/tilelang/original/examples/cast/example_triton_cast_to_fp8.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. 
+ g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
+ """ + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/tilelang/original/examples/cast/test_example_cast.py b/tilelang/original/examples/cast/test_example_cast.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b10a7979cf6506ec93c21bc8e9d3ddec2cc214 --- /dev/null +++ b/tilelang/original/examples/cast/test_example_cast.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def test_example_group_per_split_token_cast_to_fp8(): + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) + + +def test_example_per_token_cast_to_fp8(): + example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/conftest.py b/tilelang/original/examples/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..4010e0d83ae84c641151d6dd56dbf40ee42e301f --- /dev/null +++ b/tilelang/original/examples/conftest.py @@ -0,0 +1,41 @@ +import os +import random +import pytest + +os.environ["PYTHONHASHSEED"] = "0" + +random.seed(0) + +try: + import torch +except ImportError: + pass +else: + torch.manual_seed(0) + +try: + import numpy as np +except ImportError: + pass +else: + np.random.seed(0) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """Ensure that at least one test is collected. Error out if all tests are skipped.""" + known_types = { + "failed", + "passed", + "skipped", + "deselected", + "xfailed", + "xpassed", + "warnings", + "error", + } + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: + terminalreporter.write_sep( + "!", + (f"Error: No tests were collected. 
{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + ) + pytest.exit("No tests were collected.", returncode=5) diff --git a/tilelang/original/examples/convolution/README.md b/tilelang/original/examples/convolution/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ddca8a6aeb369e20c9e18d463b06f755ccb9221 --- /dev/null +++ b/tilelang/original/examples/convolution/README.md @@ -0,0 +1 @@ +# Convolution diff --git a/tilelang/original/examples/convolution/example_convolution.py b/tilelang/original/examples/convolution/example_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd3972fb081ec57e6c569f1d42e28db0caa55be --- /dev/null +++ b/tilelang/original/examples/convolution/example_convolution.py @@ -0,0 +1,111 @@ +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +def check_hopper(): + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def ref_program(stride, padding, dilation): + def main(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + return main + + +@tilelang.jit(out_idx=[2]) +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def main(argv=None): + parser = argparse.ArgumentParser() + 
parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + + out_c = kernel(a, b) + ref_c = ref_program(S, P, D)(a, b) + torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/convolution/example_convolution_autotune.py b/tilelang/original/examples/convolution/example_convolution_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..59588ac4fbd40db9e1c3d45ab0ff105af28ce004 --- /dev/null +++ b/tilelang/original/examples/convolution/example_convolution_autotune.py @@ -0,0 +1,177 @@ +import torch +import argparse +import itertools +import tilelang +import tilelang.language as T + + +def check_hopper(): + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def ref_program(stride, padding, dilation): + def main(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + return main + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tilelang.autotune(configs=get_configs()) 
+@tilelang.jit(out_idx=[2]) +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + if is_hopper: + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + if is_hopper: + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + else: + T.copy(out_local, out_flat[by * block_M, bx * block_N]) + + return main + + +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + ref_prog = ref_program(S, P, D) + + if use_autotune: + kernel = convolution(N, C, H, W, F, K, S, D, P) + else: + config = get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + ref_latency = profiler.do_bench(ref_prog) + profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for 
matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") + args = parser.parse_args() + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/tilelang/original/examples/convolution/test_example_convolution.py b/tilelang/original/examples/convolution/test_example_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4c06fb0044e0a0d531bf9abeffe484d3d48acfa1 --- /dev/null +++ b/tilelang/original/examples/convolution/test_example_convolution.py @@ -0,0 +1,21 @@ +import tilelang.testing + +import example_convolution +import example_convolution_autotune + + +# TODO(@cy): TMA with convolution must be fixed in future. +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_convolution(): + example_convolution.main([]) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_convolution_autotune(): + example_convolution_autotune.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/tilelang/original/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..18467a811898d20813b8ed1ac6b9838fd5efe59d --- /dev/null +++ b/tilelang/original/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -0,0 +1,186 @@ +from typing import Tuple + +import torch +import tilelang.testing +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(42) + + +@tilelang.jit +def tl_gemm( + M, + N, + K, + block_N, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float8_e4m3fn, + ], "Currently only float8_e4m3 is supported" + assert out_dtype in [ + T.bfloat16, + T.float32, + ], "Currently only float16 and float32 are supported" + + group_size = 128 + block_M = 128 + block_K = 128 + + A_shape = (M, K) + Scales_A_shape = (M, T.ceildiv(K, group_size)) + B_shape = (N, K) + Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size)) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = (block_M, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, T.float32), + scales_b: T.Tensor(Scales_B_shape, T.float32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + Scale_C_shared = T.alloc_shared((block_M), T.float32) + C_local = T.alloc_fragment(C_shared_shape, accum_dtype) + C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + 
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def ceildiv(a, b): + return (a + b - 1) // b + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): + # A_scale: (M, K//128) ==> (M//128, K//128, 128) + # B_scale: (N//128, K//128) ==> (N//128, K//128, 128) + # A_fp8: (M, K) + # B_fp8: (N, K) + # out_dtype: float16 or float32 + # return C: (M, N) + M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1] + A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1) + B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128) + C = torch.zeros(M, N, device="cuda", dtype=out_dtype) + c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32) + for i in range(ceildiv(M, 128)): + for j in range(ceildiv(N, 128)): + c_acc.zero_() + for k in range(ceildiv(K, 128)): + c = torch._scaled_mm( + A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, + scale_a=A_scales[i, k].view(128, 1).contiguous(), + scale_b=B_scales[j, k].view(1, 128).contiguous(), + out_dtype=torch.bfloat16, + ) + c_acc += c.to(torch.float32) + C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) + return C + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype): + kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + A = torch.randn(M, K).to(torch.bfloat16).cuda() + B = torch.randn(N, K).to(torch.bfloat16).cuda() + A_fp8, A_scale = per_token_cast_to_fp8(A.clone()) + B_fp8, B_scale = per_block_cast_to_fp8(B.clone()) + + C = torch.zeros(M, N, device="cuda", dtype=out_dtype) + + kernel(A_fp8, B_fp8, C, A_scale, B_scale) + # Get Reference Result + ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype) + diff = calc_diff(C, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + profiler = kernel.get_profiler() + 
latency = profiler.do_bench(warmup=25) + # Ensure that the latency is not None + assert latency is not None + print(f"latency: {latency} ms") + tflops = 2 * M * N * K / latency / 1e9 + print(f"tflops: {tflops}") + + +def main(): + assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32) + + +if __name__ == "__main__": + for dtype in [T.float8_e4m3fn]: + for out_dtype in [T.bfloat16, T.float32]: + for block_N in [16, 32, 64, 128]: + assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32) diff --git a/tilelang/original/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py b/tilelang/original/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..c3dac38af95a572a07e55bcc74e2fe8e2d750d6a --- /dev/null +++ b/tilelang/original/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py @@ -0,0 +1,13 @@ +import tilelang.testing + +from example_deepgemm_fp8_2xAcc import main + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_deepgemm_fp8_2xAcc(): + main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/deepseek_mla/README.md b/tilelang/original/examples/deepseek_mla/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e64b1c37d002559ea8313706340b48532a0c0b61 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/README.md @@ -0,0 +1,140 @@ +# 🚀 How to write high-performance kernel with TileLang: take MLA as an example + +TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang. + +## Introduction to MLA + +DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance. + +## Benchmark Results + +We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below. + +
+*Figure 1: Performance under batch size=64 (image: bs64_float16)*
+
+*Figure 2: Performance under batch size=128 (image: bs128_float16)*
+
+ +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. + +## Implementation + +First, let's review the core computation logic of traditional FlashAttention: + +```python +# acc_s: [block_M, block_N] +# scores_max: [block_M] +# scores_scale: [block_M] +# acc_o: [block_M, dim] + +for i in range(loop_range): + acc_s = Q @ K[i] + scores_max_prev = scores_max + scores_max = max(acc_s, dim=1) + scores_scale = exp(scores_max_prev - scores_max) + acc_o *= scores_scale + acc_s = exp(acc_s - scores_max) + acc_o = acc_s @ V[i] + ... +``` + +Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency. + +Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance. + +This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. + +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. + +Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. + +### Layout Inference + +While the above process may seem complex, but don't worry - TileLang will handle all these intricacies for you. + +Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (e.g., split the result matrix in vertical dimension). `T.copy` represents buffer-to-buffer copying operations. + +
+*Figure 3: Buffer shapes in Q @ K (image: QK Layout)*
+
+*Figure 4: Buffer shapes in acc_s @ V (image: PV Layout)*
+
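+To make the figures easier to follow, the dataflow they depict can be sketched roughly as follows. This is an illustrative fragment rather than the actual kernel in this directory: `Q_shared`, `K_shared` and `V_shared` stand in for the real operands, the buffer shapes are the ones layout inference deduces (explained below), and the `T.GemmWarpPolicy.FullCol` spelling is assumed by analogy with the `FullRow` policy used elsewhere in these examples.
+
+```python
+# Schematic view of one iteration, from the perspective of a single warpgroup (illustrative only).
+acc_s_0 = T.alloc_fragment([block_M, block_N // 2], accum_dtype)  # this warpgroup's half of Q @ K
+S_shared = T.alloc_shared([block_M, block_N], dtype)              # full score tile, exchanged via shared memory
+acc_s = T.alloc_fragment([block_M, block_N], dtype)               # full scores, needed as input to acc_s @ V
+acc_o = T.alloc_fragment([block_M, dim], accum_dtype)             # split along dim: each warpgroup owns one half
+
+T.gemm(Q_shared, K_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+T.copy(acc_s_0, S_shared)   # publish this warpgroup's half of the scores
+T.copy(S_shared, acc_s)     # read back the complete [block_M, block_N] score tile
+# (online-softmax rescaling omitted)
+T.gemm(acc_s, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+```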
+ +The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA. + +For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`. + +It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance. + +### Threadblock Swizzling + +Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions. + +In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code: + +```python +T.use_swizzle(panel_size: int, order: str = "row") +``` + +Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". + + +### Shared Memory Swizzling + +In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. + +One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency. 
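+
+As a simplified illustration, the following plain-Python snippet (not TileLang code) models a 32-bank shared memory holding a row-major 32x32 tile of 4-byte elements, and compares a column access with and without an XOR-based remap. The bank count and the particular `col ^ row` swizzle are assumptions for illustration; the swizzle TileLang actually generates may differ.
+
+```python
+BANKS = 32   # shared memory banks, each one 4-byte word wide
+TILE = 32    # 32x32 tile of 4-byte elements, stored row-major
+
+def bank(row, col):
+    return (row * TILE + col) % BANKS            # bank hit by element [row][col]
+
+def bank_swizzled(row, col):
+    return (row * TILE + (col ^ row)) % BANKS    # same element after XOR-swizzling the column index
+
+# 32 threads of a warp each read column 0 of their own row (a transposed-style access):
+naive = {bank(t, 0) for t in range(32)}
+swizzled = {bank_swizzled(t, 0) for t in range(32)}
+print(len(naive), "distinct bank(s) without swizzling")   # 1  -> fully serialized, 32-way conflict
+print(len(swizzled), "distinct bank(s) with swizzling")   # 32 -> conflict-free
+```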
+ +Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code: + +```python +T.annotate_layout({ + S_shared: TileLang.layout.make_swizzled_layout(S_shared), +}) +``` + +Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. + + +### Warp-Specialization + +The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. + +In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. + + +### Pipeline + + +Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: + +```python +T.pipelined(range: int, stage: int) +``` + +Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. + + +### Split-KV + +We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. + +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. \ No newline at end of file diff --git a/tilelang/original/examples/deepseek_mla/amd/README.md b/tilelang/original/examples/deepseek_mla/amd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cc0fb576dce7e77e60d3cf5eedda2732dc3f546a --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/amd/README.md @@ -0,0 +1,52 @@ +# 🚀 High-Performance FlashMLA Implementation Using TileLang on AMD MI300X Accelerators + +Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. + +## Architectural Considerations and Optimization Strategies + +Key implementation differences between Hopper and MI300X architectures include: + +1. 
**Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations. + +2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes: + - Reducing software pipeline stages + - Register-based caching of Q matrices instead of shared memory utilization: + ```python + # Original shared memory allocation + Q_shared = T.alloc_shared([block_H, dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + + # Optimized register allocation + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + ``` + +3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64. + +4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code. + +## Performance Evaluation + +We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate: + +
+*Figure 1: Computational throughput comparison across frameworks, batch sizes 64 and 128 (image: AMD FlashMLA performance comparison)*
+
+ +Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) (from 0.73x to 1.21x) in most test cases, while significantly outperforming Triton (up to 6.5x faster)implementations. This performance is achieved through a concise 70-line Python implementation! + +## Future Optimization Opportunities + +1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction. + +2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to: + - Reduce shared memory pressure + - Improve compute-to-memory access ratios + - Enhance parallelism through dimension-wise task distribution + +## Acknowledgment + +We would like to express our sincere gratitude to the AMD ROCm and Composable Kernel team for their outstanding contributions. We have learned a great deal from the ROCm software stack. diff --git a/tilelang/original/examples/deepseek_mla/amd/autotuner.log b/tilelang/original/examples/deepseek_mla/amd/autotuner.log new file mode 100644 index 0000000000000000000000000000000000000000..c1d7b03eb31489122a96017fc65458fe4eff0355 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/amd/autotuner.log @@ -0,0 +1 @@ +2026-03-25 09:34:07,764 WARNING:Tunable parameters ['block_N', 'block_H', 'num_split', 'threads'] already provided during auto-tuning. Skipping compilation and using direct JIT diff --git a/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..a9035793b9f305ee330185db329e615710488635 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +def get_configs(): + import itertools + + BLOCK_N = [16, 32, 64, 128] + BLOCK_H = [16, 32, 64, 128] + num_split = [1, 2, 4, 8, 16, 32] + threads = [128, 256] + + _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) + + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], 
dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=0): + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # T.copy(acc_s, S_shared) + T.copy(acc_s, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * 
VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=0): + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: 
T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + enable_autotune = args.autotune + + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 32 + BLOCK_H = 64 + num_split = 4 + threads = 128 + + if enable_autotune: + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) + else: + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) + profiler = 
kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + input_tensors = profiler._get_inputs() + tilelang_output = kernel(*input_tensors) + ref_output = ref_program(*input_tensors) + print(f"Tilelang output: {tilelang_output}") + print(f"Ref output: {ref_output}") + torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") diff --git a/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..18c0a5f86d7625af022832d36f58b123c0feb0f8 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py @@ -0,0 +1,512 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * 
stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + num_stages=1, # 2 will oom in amd + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 
= cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + 
block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_mla_triton"]: + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="torch") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else 
f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..861e841c4ec8b68851cd4bfdbfdce0fede87960f --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -0,0 +1,509 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse 
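+    # Run the eager reference once for the correctness check, then time it with triton.testing.do_bench.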
+ + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + 
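+        # per-token strides of the paged latent (kv_c) and rope (k_pe) caches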
kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + num_stages=1, # 2 will oom in amd + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +FUNC_TABLE = { + 
"torch": run_torch_mla, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_mla_triton"]: + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", 
+ "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="flash_mla_triton") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/tilelang/original/examples/deepseek_mla/benchmark_mla.py b/tilelang/original/examples/deepseek_mla/benchmark_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..544b5e1285c173e1521f049e1de9521baa53afee --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/benchmark_mla.py @@ -0,0 +1,628 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench +from example_mla_decode_paged import mla_decode_tilelang + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + 
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + from flash_mla import flash_mla_with_kvcache, get_mla_metadata + + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + + +@torch.inference_mode() +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + # pip install flashinfer-python + import flashinfer + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d - dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flashinfer(): + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flashinfer() + t = triton.testing.do_bench(flashinfer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + 
stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + + +@triton.jit +def 
_mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +@torch.inference_mode() +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = 
blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) + + def flash_mla_tilelang(): + out = kernel( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "tilelang": run_flash_mla_tilelang, + "flash_mla": run_flash_mla, + "flashinfer": run_flashinfer, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + # flashinfer has a different lse return value + # flash_mla_triton and flash_mla_tilelang doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + 
random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "tilelang", + "flash_mla", + "flashinfer", + "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="tilelang") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/tilelang/original/examples/deepseek_mla/example_mla_decode.py b/tilelang/original/examples/deepseek_mla/example_mla_decode.py new file mode 100644 index 
0000000000000000000000000000000000000000..0d141b4b39500f7860a3308158e5bc5c30d4fb5d --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/example_mla_decode.py @@ -0,0 +1,301 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): + scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j 
in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, hid * 
VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split[0] = glse[bz, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, 
num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/tilelang/original/examples/deepseek_mla/example_mla_decode_paged.py b/tilelang/original/examples/deepseek_mla/example_mla_decode_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..23001bde8a1fc7846d45f406d9e0db4d95252ece --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/example_mla_decode_paged.py @@ -0,0 +1,378 @@ +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse +from tilelang.profiler import do_bench +import math + + +@tilelang.jit( + out_idx=[8], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): + if softmax_scale is None: + softmax_scale = (dv + dpe) ** -0.5 + scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = h_q // h_kv + VALID_BLOCK_H = min(block_H, kv_group_num) + assert h_kv == 1, "h_kv must be 1" + assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" + + @T.macro + def flash_mla_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), + Output: 
T.Tensor([batch, h_q, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + if kr == 0: + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_mla_split_kv_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, 
bz): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), 
+ ): + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, 
h_q, h_kv, d, dv, causal, dtype): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + out_torch, _ = ref_mla() + return out_torch + + +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = min(64, h_q // h_kv) + softmax_scale = d**-0.5 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + def flash_mla_tilelang(): + out = profiler.func( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) + print("All close") + return out_flash, t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 64 + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 + + total_flops = s_q * total_seqlens * h_q * d * 2 + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + 
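+ # With the defaults above, each query/key head has d = 576 = dv + dpe dims: a 512-dim "no-RoPE" (compressed KV / value) part + # and a 64-dim rotary part. run_tilelang_mla splits q and blocked_k along the last axis into these two pieces, and the kernel + # issues one GEMM per piece when forming the attention scores.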
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/tilelang/original/examples/deepseek_mla/example_mla_decode_persistent.py b/tilelang/original/examples/deepseek_mla/example_mla_decode_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a1300a239b57de9ac82c4f1705bea20920e873 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/example_mla_decode_persistent.py @@ -0,0 +1,209 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from tilelang.carver.arch import driver +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + sm_num = driver.get_num_sms() + + @T.prim_func + def main_split_persistent( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(sm_num, threads=256) as (block_id): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + # O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + T.use_swizzle(10) + + total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split + waves = T.ceildiv(total_tiles, sm_num) 
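+ # Persistent scheduling: the grid has only sm_num blocks, and each block walks the (batch, head-group, split) tiles in a + # grid-stride fashion, handling tile_id = sm_num * w + block_id on wave w and skipping tiles that decode out of range; + # after T.sync_grid() the same resident blocks are reused for the split-combine pass.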
+ for w in T.serial(waves): + tile_id = sm_num * w + block_id + bid = tile_id // ((heads // min(block_H, kv_group_num)) * num_split) + hid = tile_id // num_split % (heads // min(block_H, kv_group_num)) + sid = tile_id % num_split + cur_kv_head = hid // (kv_group_num // block_H) + + if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = (seqlen_kv // num_split) * sid + k * block_N + kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) + # T.copy(acc_o, O_shared) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) + + T.sync_grid() + waves = T.ceildiv(heads * batch, sm_num) + for w in T.serial(waves): + tile_id = sm_num * w + block_id + hid = tile_id // batch + bid = tile_id % batch + if bid < batch and hid < heads: + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bid, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bid, hid, k, i] + lse_local_split[0] = glse[bid, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bid, hid, i] = o_accum_local[i] + + return main_split_persistent + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] 
+ # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + num_split = 2 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/deepseek_mla/example_mla_decode_ws.py b/tilelang/original/examples/deepseek_mla/example_mla_decode_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..8e317fa00183e39581f460ff6159efa4aefb7d52 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/example_mla_decode_ws.py @@ -0,0 +1,606 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + ], 
+) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): + sm_scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): + Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) + Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) + Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype) + K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + sumexp = T.alloc_fragment([block_H], accum_dtype) + sum_exp_shared = T.alloc_shared([block_H], accum_dtype) + sumexp_i = T.alloc_fragment([block_H], accum_dtype) + alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([block_H], accum_dtype) + m_i = T.alloc_fragment([block_H], accum_dtype) + m_i_prev = T.alloc_fragment([block_H], accum_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + cur_kv_head = hid // (kv_group_num // block_H) + NI = T.ceildiv((seqlen_kv // num_split), block_N) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, out=m_i, dim=1, clear=False) + for h_i in 
T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(block_H): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(block_H): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) + + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 
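+ # Producer side of the double buffer: wait until both consumer warpgroups have released buffer 0 (the phase bit flips with + # i_i), fill it with a block_N x dim KV tile (left/right halves plus the pe tail) via async copies, then signal + # bar_k_0_ready; buffer 1 is filled the same way in the second half of the iteration, so loads of one buffer overlap the + # GEMMs on the other.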
+ T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): + Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) + Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) + Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype) + K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + sumexp = T.alloc_fragment([block_H], accum_dtype) + sum_exp_shared = T.alloc_shared([block_H], accum_dtype) + sumexp_i = T.alloc_fragment([block_H], accum_dtype) + alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([block_H], accum_dtype) + m_i = T.alloc_fragment([block_H], accum_dtype) + m_i_prev = T.alloc_fragment([block_H], accum_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + 
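+ # Barrier roles: bar_q marks that all 384 threads have staged Q; bar_k_*_ready is signalled by the 128 producer threads once + # a KV buffer is filled; bar_k_*_free is signalled by the 256 consumer threads once they are done reading it; the + # bar_sScale_and_sS_* pair below hands S_shared and alpha_shared from the softmax warpgroup to the warpgroup that accumulates + # the right half of the output.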
bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + cur_kv_head = hid // (kv_group_num // block_H) + NI = T.ceildiv((seqlen_kv // num_split), block_N) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? 
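+ # Online-softmax update for this tile: with m_new = max(m_old, rowmax(acc_s)) and alpha = exp2((m_old - m_new) * sm_scale), + # sumexp_i holds the per-row sums rowsum(exp2(acc_s * sm_scale - m_new * sm_scale)) for this tile; the loop below folds them + # into the running denominator (sumexp = sumexp * alpha + sumexp_i), and acc_o_l is rescaled by alpha before the new P*V + # contribution is added, so each KV tile is visited exactly once.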
+ for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(block_H): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(block_H): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) + + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 
256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split[0] = glse[bz, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, 
groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/tilelang/original/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/tilelang/original/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..fa39fa498f552c9409dfbd313a85a985e710c054 --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -0,0 +1,150 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + q_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): + 
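+ # Kernel structure (descriptive comments): each program handles one (batch, query-head-group) tile.
+ # The KV cache lives in global memory as fp8 (e4m3); every loop iteration stages a [block_N, dim]
+ # tile into qKV_shared and then copies it into the fp16 buffer KV_shared, so both the QK^T and the
+ # P*V GEMMs run in fp16 while only the global KV storage stays fp8. Scores follow the usual online
+ # softmax recurrence in base 2 (the scale already folds in log2(e)): acc_o is rescaled each step
+ # and divided by the running logsum at the end.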
Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + qKV_shared = T.alloc_shared([block_N, dim], q_dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.disable_warp_group_reg_alloc() + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(qKV_shared, KV_shared) + + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) 
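+ # Scores in this reference are computed over the full (dim + pe_dim) channels: the pe part is
+ # concatenated onto both query and key, mirroring the kernel's two QK GEMMs, while the PV product
+ # below uses only the dim-wide `kv`, so the pe channels never reach the output.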
+ key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") diff --git a/tilelang/original/examples/deepseek_mla/figures/bs128_float16.png b/tilelang/original/examples/deepseek_mla/figures/bs128_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..3cf24c84b82532bf422efee26afe61b4ae0e1948 Binary files /dev/null and b/tilelang/original/examples/deepseek_mla/figures/bs128_float16.png differ diff --git a/tilelang/original/examples/deepseek_mla/figures/bs64_float16.png b/tilelang/original/examples/deepseek_mla/figures/bs64_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..15807c3d2e57f5a2848b792d0fe746db31be455d Binary files /dev/null and b/tilelang/original/examples/deepseek_mla/figures/bs64_float16.png differ diff --git a/tilelang/original/examples/deepseek_mla/figures/flashmla-amd.png b/tilelang/original/examples/deepseek_mla/figures/flashmla-amd.png new file mode 100644 index 0000000000000000000000000000000000000000..75470bb30184b866402124fe1917eb7591623a7e Binary files /dev/null and b/tilelang/original/examples/deepseek_mla/figures/flashmla-amd.png differ diff --git a/tilelang/original/examples/deepseek_mla/figures/pv_layout.jpg b/tilelang/original/examples/deepseek_mla/figures/pv_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..79b0c8cf301d9c04eef050c893156c71549ce03d Binary files /dev/null and b/tilelang/original/examples/deepseek_mla/figures/pv_layout.jpg differ diff --git a/tilelang/original/examples/deepseek_mla/figures/qk_layout.jpg b/tilelang/original/examples/deepseek_mla/figures/qk_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d5bd923d0d8ab1fe5edece222f31777ccd0d746 Binary files /dev/null and b/tilelang/original/examples/deepseek_mla/figures/qk_layout.jpg differ diff --git a/tilelang/original/examples/deepseek_mla/test_example_mla_decode.py b/tilelang/original/examples/deepseek_mla/test_example_mla_decode.py new file mode 100644 index 
0000000000000000000000000000000000000000..a269ea57aed102b83596d4c7a896322fb105fbfb --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/test_example_mla_decode.py @@ -0,0 +1,12 @@ +import tilelang.testing +import example_mla_decode + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mla_decode(): + example_mla_decode.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/deepseek_mla/torch_refs.py b/tilelang/original/examples/deepseek_mla/torch_refs.py new file mode 100644 index 0000000000000000000000000000000000000000..aae6c7cd2b619afee90f39058cfd9a4a6a71e49e --- /dev/null +++ b/tilelang/original/examples/deepseek_mla/torch_refs.py @@ -0,0 +1,81 @@ +import torch + +num_split = 1 + + +def flash_split_ref(Q, Q_pe, KV, K_pe): + dim = Q.shape[-1] + pe_dim = Q_pe.shape[-1] + batch = Q.size(0) + nheads = Q.size(1) + block_N = 64 + seqlen_kv = KV.size(1) + + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + Q_pe_ = Q_pe * scale + KV_ = KV.expand(-1, -1, nheads, -1) + K_pe_ = K_pe.expand(-1, -1, nheads, -1) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] + acc_s += torch.einsum( + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, None] + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :] = acc_o + glogsum[ks, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, Q_pe, KV, K_pe, glse, Output_partial): + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = 
torch.empty_like(glse[:, :, 0]).fill_(0) + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + scale = torch.exp2(lse - lse_logsum) + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) diff --git a/tilelang/original/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/tilelang/original/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..dadb4b4cb916ebc6f8ebce9120071be10a55e5e4 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -0,0 +1,954 @@ +# ruff: noqa + +import torch +import time +import argparse +import tilelang +from tilelang import language as T +import tilelang.testing +from typing import Optional, Union +from einops import rearrange, repeat +import triton +import triton.language as tl +from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + 
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + o_swa: Optional[torch.Tensor], + lse_slc: torch.Tensor, + lse_swa: Optional[torch.Tensor], + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + 
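+ # Tensors and launch metadata are stashed on ctx for a potential backward pass; this
+ # forward-only benchmark never invokes it (no backward is defined for this class in this file).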
ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
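+ Note:
+ When `window_size` is 0 (the default here), the sliding-window branch is skipped entirely
+ and the result is just the gate-weighted selected-attention output `o_slc * g_slc`.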
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + S = block_indices.shape[-1] + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + if isinstance(block_counts, torch.Tensor): + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o_slc = torch.zeros_like(v) + o_swa = torch.zeros_like(v) if window_size > 0 else None + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[i] + else: + s_b = block_counts + else: + T = cu_seqlens[i + 1] - cu_seqlens[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] + else: + s_b = block_counts + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [HQ] + g_slc_i = g_slc_b[i_q] + # [HQ] + g_swa_i = g_swa_b[i_q] + # [S*BS, HQ] + i_i = i_b[i_q] + # [HQ] + if isinstance(block_counts, torch.Tensor): + s_i = s_b[i_q] + else: + s_i = s_b + # [S*BS, HQ, -1] + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) + if not varlen: + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + else: + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + if window_size > 0: + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) + if not varlen: + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + else: + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + + if head_first: + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") + + return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) + + +def get_configs(): + import itertools + + iter_params = dict( 
+ block_T=[128, 256, 512], + num_stages=[0, 1, 2, 4, 5], + threads=[32, 64, 128, 256, 512], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(block_T, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + + @T.prim_func + def tilelang_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_shared([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)}) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= 
scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return tilelang_sparse_attention + + +def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): + """Generate random block indices for the benchmark.""" + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") + + for b in range(batch): + for t in range(seq_len): + for h in range(heads): + i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] + block_indices[b, t, h, : len(i_i)] = i_i + + return block_indices.sort(-1)[0] + + +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): + """Benchmark the TileLang Sparse Attention implementation.""" + + # Set random seed for reproducibility + tilelang.testing.set_random_seed(0) + torch.random.manual_seed(0) + + # Compile the NSA kernel + kernel = tilelang_sparse_attention( + batch=batch_size, + heads=head_query, + seq_len=seq_len, + dim=dim, + is_causal=True, + block_size=block_size, + groups=head_query // heads, + selected_blocks=selected_blocks, + scale=scale, + ) + + profiler = kernel.get_profiler() + + profiler_latency = profiler.do_bench() + print(f"Profiler latency: {profiler_latency} ms") + + # Create input tensors + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + + # Generate block indices + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) + + # Warmup + for _ in range(warmup): + kernel(Q, K, V, block_indices, out) + + # Synchronize before timing + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(iterations): + kernel(Q, K, V, block_indices, out) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate metrics + elapsed_time = end_time - start_time + avg_time = elapsed_time / iterations * 1000 # ms + + # Calculate FLOPs (approximate for NSA) + # Each token attends to selected_blocks * block_size tokens + # Each attention calculation involves 2*dim FLOPs for QK + # And another 2*dim FLOPs for attention * V + flops_per_token = 4 * dim * selected_blocks * block_size + total_flops = batch_size * seq_len * head_query * flops_per_token + flops_per_sec = total_flops / (elapsed_time / iterations) + tflops = flops_per_sec / 1e12 + + # Validate result against reference if requested + if validate: + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2) + if is_valid: + 
print("Validation: PASSED") + else: + print("Validation: FAILED") + print(f"Max difference: {(ref - out).abs().max().item()}") + + # Return benchmark results + return { + "avg_time_ms": avg_time, + "tflops": tflops, + "batch_size": batch_size, + "seq_len": seq_len, + "heads": heads, + "head_query": head_query, + "dim": dim, + "selected_blocks": selected_blocks, + "block_size": block_size, + } + + +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): + """Benchmark the Triton-based TileLang Sparse Attention implementation.""" + + # Set random seed for reproducibility + tilelang.testing.set_random_seed(0) + torch.random.manual_seed(0) + + # Create input tensors + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + + # Generate block indices + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") + + # Warmup + for _ in range(warmup): + out = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=0, + scale=scale, + ) + + # Synchronize before timing + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(iterations): + out = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=0, + scale=scale, + ) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate metrics + elapsed_time = end_time - start_time + avg_time = elapsed_time / iterations * 1000 # ms + + # Calculate FLOPs (approximate for NSA) + flops_per_token = 4 * dim * selected_blocks * block_size + total_flops = batch_size * seq_len * head_query * flops_per_token + flops_per_sec = total_flops / (elapsed_time / iterations) + tflops = flops_per_sec / 1e12 + + # Validate result against reference if requested + if validate: + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2) + if is_valid: + print("Validation: PASSED") + else: + print("Validation: FAILED") + print(f"Max difference: {(ref - out).abs().max().item()}") + + # Return benchmark results + return { + "avg_time_ms": avg_time, + "tflops": tflops, + "batch_size": batch_size, + "seq_len": seq_len, + "heads": heads, + "head_query": head_query, + "dim": dim, + "selected_blocks": selected_blocks, + "block_size": block_size, + } + + +def run_benchmark_suite(impl="all"): + """Run a suite of benchmarks with different configurations.""" + + # Define configurations to benchmark + configs = [ + # 
Small model config - Note: head_query must be a multiple of heads*16 for Triton + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, + # Medium model config + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, + # Large model config + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, + ] + + results = [] + for config in configs: + print(f"Running benchmark with config: {config}") + + if impl in ["all", "tilelang"]: + print("Benchmarking TileLang implementation:") + result = benchmark_nsa( + batch_size=config["batch_size"], + seq_len=config["seq_len"], + heads=config["heads"], + head_query=config["head_query"], + dim=config["dim"], + selected_blocks=config["selected_blocks"], + block_size=config["block_size"], + dtype=torch.float16, + scale=0.1, + validate=False, + ) + results.append({"impl": "tilelang", **result}) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if impl in ["all", "triton"]: + print("Benchmarking Triton implementation:") + result = benchmark_triton_nsa( + batch_size=config["batch_size"], + seq_len=config["seq_len"], + heads=config["heads"], + head_query=config["head_query"], + dim=config["dim"], + selected_blocks=config["selected_blocks"], + block_size=config["block_size"], + dtype=torch.float16, + scale=0.1, + validate=False, + ) + results.append({"impl": "triton", **result}) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if impl in ["all"]: + # Print comparison if both implementations were run + tilelang_result = next( + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) + triton_result = next( + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) + speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] + print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") + + print("-" * 50) + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark TileLang Sparse Attention") + parser.add_argument("--batch", type=int, default=32, help="Batch size") + parser.add_argument("--seq_len", type=int, default=1024, help="Sequence length") + parser.add_argument("--heads", type=int, default=1, help="Number of heads") + parser.add_argument("--head_query", type=int, default=16, help="Number of query heads") + parser.add_argument("--dim", type=int, default=128, help="Head dimension") + parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") + parser.add_argument("--block_size", type=int, default=32, help="Block size") + parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)") + parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") + parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") + parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") + parser.add_argument("--validate", action="store_true", help="Validate against reference") + parser.add_argument("--suite", action="store_true", help="Run benchmark suite") + parser.add_argument( + "--impl", 
+ type=str, + default="all", + choices=["tilelang", "triton", "all"], + help="Implementation to benchmark (tilelang, triton, or all)", + ) + + args = parser.parse_args() + + # For Triton impl, ensure head_query is a multiple of heads*16 + if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: + # Adjust head_query to nearest valid value + args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + + if args.suite: + run_benchmark_suite(impl=args.impl) + else: + dtype = torch.float16 if args.dtype == T.float16 else torch.float32 + + if args.impl in ["tilelang", "all"]: + print("Benchmarking TileLang implementation:") + result = benchmark_nsa( + batch_size=args.batch, + seq_len=args.seq_len, + heads=args.heads, + head_query=args.head_query, + dim=args.dim, + selected_blocks=args.selected_blocks, + block_size=args.block_size, + dtype=dtype, + scale=args.scale, + warmup=args.warmup, + iterations=args.iterations, + validate=args.validate, + ) + print("\nBenchmark Results (TileLang):") + print( + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if args.impl in ["triton", "all"]: + print("Benchmarking Triton implementation:") + result = benchmark_triton_nsa( + batch_size=args.batch, + seq_len=args.seq_len, + heads=args.heads, + head_query=args.head_query, + dim=args.dim, + selected_blocks=args.selected_blocks, + block_size=args.block_size, + dtype=dtype, + scale=args.scale, + warmup=args.warmup, + iterations=args.iterations, + validate=args.validate, + ) + print("\nBenchmark Results (Triton):") + print( + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..41f1dd86b99833d56b3164c865f55d6ae315311e --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -0,0 +1,865 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange +import tilelang + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tilelang_kernel_fwd( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, +): + from tilelang import language as T + + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # 
log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + o_slc_shape = [batch, seq_len, heads, dim] + lse_slc_shape = [batch, seq_len, heads] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for k in T.Parallel(G): + scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for k in T.Parallel(G): + logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k] + T.copy(acc_s, acc_s_cast) + + # Rescale + for k, j in T.Parallel(G, BV): + acc_o[k, j] *= scores_scale[k] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy( + O_shared, + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], + ) + for i in T.Parallel(G): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) + + return native_sparse_attention + + +@tilelang.jit( + pass_configs={ + 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def tilelang_kernel_bwd_dkv( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + dtype=T.float16, + accum_dtype=T.float32, +): + if scale is None: + sm_scale = (1.0 / dim) ** 0.5 + else: + sm_scale = scale + + scale = sm_scale * 1.44269504 + + from tilelang import language as T + + B = batch + BS = block_size + G = groups + V = dim + K = dim + BK = tilelang.next_power_of_2(K) + BV = min(128, tilelang.next_power_of_2(dim)) + NS = tilelang.cdiv(seq_len, BS) + NV = tilelang.cdiv(V, BV) + + heads_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + k_shape = [batch, seq_len, heads_kv, dim] + v_shape = [batch, seq_len, heads_kv, dim] + lse_slc_shape = [batch, seq_len, heads] + delta_slc_shape = [batch, seq_len, heads] + o_shape = [batch, heads, seq_len, dim] + do_slc_shape = [batch, seq_len, heads, dim] + dk_shape = [NV, batch, seq_len, heads_kv, dim] + dv_shape = [batch, seq_len, heads_kv, dim] + + block_mask_shape = [batch, seq_len, heads_kv, NS] + num_threads = 32 + print("NV", NV, "NS", NS, "B", B, "H", H) + + @T.prim_func + def flash_bwd_dkv( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), + ): + with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + Q_shared = T.alloc_shared([G, BK], dtype) + qkT = T.alloc_fragment([BS, G], accum_dtype) + qkT_cast = T.alloc_fragment([BS, G], dtype) + dsT = T.alloc_fragment([BS, G], accum_dtype) + dsT_cast = T.alloc_fragment([BS, G], dtype) + lse_shared = T.alloc_shared([G], accum_dtype) + delta = T.alloc_shared([G], accum_dtype) + + do = T.alloc_shared([G, BV], dtype) + dv = T.alloc_fragment([BS, BV], accum_dtype) + dk = T.alloc_fragment([BS, BK], accum_dtype) + dq = T.alloc_fragment([BS, G], accum_dtype) + + dv_shared = T.alloc_shared([BS, BV], dtype) + dk_shared = T.alloc_shared([BS, BK], dtype) + + i_b, i_h = i_bh // H, i_bh % H + + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) + + # [BS, BK] + T.clear(dk) + # [BS, BV] + T.clear(dv) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + loop_st = i_s * BS + loop_ed = seq_len + for i in T.Pipelined( + start=loop_st, + stop=loop_ed, + num_stages=0, + ): + b_m_slc = BlockMask[i_b, i, i_h, i_s] + if b_m_slc != 0: + # [G, BK] + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) + T.clear(qkT) + # [BS, BK] @ [G, BK] -> [BS, G] + T.gemm( + K_shared, + Q_shared, + qkT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + # [G] + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) + + # [G, BV] + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) + T.clear(dsT) + # [BS, BV] @ [G, BV] -> [BS, G] + T.gemm( 
+ V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(qkT, qkT_cast) + # [BS, G] @ [G, BV] -> [BS, BV] + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + # [G] + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for i, j in T.Parallel(BS, G): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + + # [BS, G] @ [G, BK] -> [BS, BK] + T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) + + return flash_bwd_dkv + + +def make_dq_layout(dQ): + from tilelang import language as T + + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout( + dQ.shape, + lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2], + ) + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def tilelang_kernel_bwd_dqkv( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + dtype=T.float16, + accum_dtype=T.float32, +): + if scale is None: + sm_scale = (1.0 / dim) ** 0.5 + else: + sm_scale = scale + + scale = sm_scale * 1.44269504 + + from tilelang import language as T + + B = batch + BS = block_size + G = groups + V = dim + K = dim + BK = tilelang.next_power_of_2(K) + BV = min(128, tilelang.next_power_of_2(dim)) + NS = tilelang.cdiv(seq_len, BS) + NV = tilelang.cdiv(V, BV) + + heads_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + k_shape = [batch, seq_len, heads_kv, dim] + v_shape = [batch, seq_len, heads_kv, dim] + lse_slc_shape = [batch, seq_len, heads] + delta_slc_shape = [batch, seq_len, heads] + o_shape = [batch, heads, seq_len, dim] + do_slc_shape = [batch, seq_len, heads, dim] + dq_shape = [NV, batch, seq_len, heads, dim] + dk_shape = [NV, batch, seq_len, heads_kv, dim] + dv_shape = [batch, seq_len, heads_kv, dim] + + block_mask_shape = [batch, seq_len, heads_kv, NS] + num_threads = 32 + + @T.prim_func + def flash_bwd_dqkv( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), + ): + with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): + K_shared = T.alloc_shared([BS, BK], dtype) + dsT_shared = T.alloc_shared([BS, G], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + Q_shared = T.alloc_shared([G, BK], dtype) + qkT = T.alloc_fragment([BS, G], accum_dtype) + qkT_cast = T.alloc_fragment([BS, G], dtype) + dsT = T.alloc_fragment([BS, G], accum_dtype) + dsT_cast = T.alloc_fragment([BS, G], dtype) + lse_shared = T.alloc_shared([G], accum_dtype) + delta = T.alloc_shared([G], accum_dtype) + + do = T.alloc_shared([G, BV], dtype) + dv = T.alloc_fragment([BS, BV], accum_dtype) + dk = T.alloc_fragment([BS, BK], accum_dtype) + dq = T.alloc_fragment([G, BK], accum_dtype) + + dv_shared = T.alloc_shared([BS, BV], dtype) + dk_shared = T.alloc_shared([BS, BK], dtype) + + i_b, i_h = i_bh // H, i_bh % H + + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) 
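+ # Descriptive comments: each program instance owns one key/value block (index i_s). The pipelined
+ # loop below visits every query row from i_s * BS onward whose block mask selects this block,
+ # recomputes the attention probabilities from the stored LSE, and accumulates dK/dV in register
+ # fragments; dQ contributions are scattered back to global memory with per-element atomic adds
+ # (cf. make_dq_layout above, whose comment notes that atomicAdd cannot be vectorized).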
+ + # [BS, BK] + T.clear(dk) + # [BS, BV] + T.clear(dv) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + loop_st = i_s * BS + loop_ed = seq_len + for i in T.Pipelined( + start=loop_st, + stop=loop_ed, + num_stages=0, + ): + b_m_slc = BlockMask[i_b, i, i_h, i_s] + if b_m_slc != 0: + # [G, BK] + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) + T.clear(qkT) + # [BS, BK] @ [G, BK] -> [BS, G] + T.gemm( + K_shared, + Q_shared, + qkT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + # [G] + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) + + # [G, BV] + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) + T.clear(dsT) + # [BS, BV] @ [G, BV] -> [BS, G] + T.gemm( + V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(qkT, qkT_cast) + # [BS, G] @ [G, BV] -> [BS, BV] + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + # [G] + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for _i, _j in T.Parallel(BS, G): + dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale + + # [BS, G] @ [G, BK] -> [BS, BK] + T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + # [BS, G] * [BS, BK] -> [G, BK] + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for _i, _j in T.Parallel(G, BK): + T.atomic_add(DQ[i_v, i_b, i, i_h * G + _i, _j], dq[_i, _j]) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) + + return flash_bwd_dqkv + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tilelang_kernel_preprocess( + batch, + heads, + seq_len, + dim, + dtype=T.float16, + accum_dtype=T.float32, + blk=32, +): + from tilelang import language as T + + shape = [batch, seq_len, heads, dim] + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) + + return flash_bwd_prep + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tilelang_kernel_block_mask( + batch, + heads, + seq_len, + selected_blocks, + block_size, + dtype=T.int32, +): + from tilelang import language as T + + block_indices_shape = [batch, seq_len, heads, selected_blocks] + block_counts_shape = [batch, seq_len, heads] + S 
= selected_blocks + BS = block_size + NS = tilelang.cdiv(seq_len, BS) + + block_mask_shape = [batch, seq_len, heads, NS] + USE_BLOCK_COUNTS = block_counts is not None + + @T.prim_func + def flash_bwd_block_mask( + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + ): + with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): + i_t, i_b, i_hs = bx, by, bz + i_h, i_s = i_hs // S, i_hs % S + b_i = BlockIndices[i_b, i_t, i_h, i_s] + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < BlockCounts[i_b, i_t, i_h].astype(i_s.dtype) + BlockMask[i_b, i_t, i_h, i_s] = b_m + else: + b_m = b_i * BS <= i_t + BlockMask[i_b, i_t, i_h, i_s] = b_m + + return flash_bwd_block_mask + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + lse_slc: torch.Tensor, + do_slc: torch.Tensor, + o_swa: torch.Tensor, + lse_swa: torch.Tensor, + do_swa: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + window_size: int = 0, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + assert window_size == 0, "Window size is not supported yet" + delta_slc = tilelang_kernel_preprocess(B, HQ, T, K)(o_slc, do_slc) + + dq = torch.zeros(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) + + fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( + batch=B, + heads=HQ, + seq_len=T, + dim=K, + is_causal=True, + block_size=BS, + groups=G, + selected_blocks=S, + scale=scale, + ) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) + + dq = dq.sum(0) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + block_indices, + block_counts, + block_size, + window_size, + scale, + offsets, + ): + ctx.dtype = q.dtype + assert offsets is None, "Offsets are not supported yet" + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + B, SEQLEN, HQ, D = q.shape + H = k.shape[2] + G = HQ // H + S = block_indices.shape[-1] + V = v.shape[-1] + kernel = tilelang_kernel_fwd( + batch=B, + heads=HQ, + seq_len=SEQLEN, + dim=D, + is_causal=True, + scale=scale, + block_size=block_size, + groups=G, + selected_blocks=S, + ) + o_slc = torch.empty(B, SEQLEN, HQ, D, dtype=v.dtype, device=q.device) + lse_slc = torch.empty(B, SEQLEN, HQ, dtype=torch.float, device=q.device) + kernel(q, k, v, 
block_indices.to(torch.int32), o_slc, lse_slc) + + ctx.save_for_backward(q, k, v, o_slc, lse_slc) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), lse_slc.to(torch.float) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return ( + dq.to(q), + dk.to(k), + dv.to(v), + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, SEQLEN, HQ, K]` if `head_first=False` else `[B, HQ, SEQLEN, K]`. + k (torch.Tensor): + keys of shape `[B, SEQLEN, H, K]` if `head_first=False` else `[B, H, SEQLEN, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, SEQLEN, H, V]` if `head_first=False` else `[B, H, SEQLEN, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, SEQLEN, H, S]` if `head_first=False` else `[B, H, SEQLEN, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, SEQLEN, H]` if `head_first=True` else `[B, SEQLEN, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. 
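+
+    Example (illustrative sketch, mirroring the self-test at the bottom of this file; the
+    concrete shapes and values below are assumptions, not requirements):
+        >>> B, SEQLEN, H, HQ, D, S = 1, 32, 1, 16, 32, 1
+        >>> q = torch.randn(B, SEQLEN, HQ, D, dtype=torch.float16, device="cuda", requires_grad=True)
+        >>> k = torch.randn(B, SEQLEN, H, D, dtype=torch.float16, device="cuda", requires_grad=True)
+        >>> v = torch.randn(B, SEQLEN, H, D, dtype=torch.float16, device="cuda", requires_grad=True)
+        >>> g_slc = torch.ones(B, SEQLEN, HQ, dtype=torch.float16, device="cuda", requires_grad=True)
+        >>> g_swa = torch.ones(B, SEQLEN, HQ, dtype=torch.float16, device="cuda", requires_grad=True)
+        >>> block_indices = torch.zeros(B, SEQLEN, H, S, dtype=torch.long, device="cuda")
+        >>> block_counts = torch.ones(B, SEQLEN, H, dtype=torch.long, device="cuda")
+        >>> o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_counts, block_size=32)
+        >>> o.sum().backward()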
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 1, 32, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + # assert_close(" o", ref, tri, 0.004) + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2) diff --git a/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eea58049b388ceca54c7f1883d1d5a4ab755a6 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -0,0 +1,176 @@ +# ruff: noqa +import torch +from reference import naive_nsa_simple_inference +import tilelang +from tilelang import 
language as T +import tilelang.testing + +tilelang.testing.set_random_seed(42) + + +# TODO(lei): workaround, as threads is not divisible by warp group size, +# auto warp specialization may have some bugs. +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def native_sparse_attention( + batch, + heads, + seq_len, # Length of K/V sequences (context window size) + dim, # Embedding dimension per head + scale=None, + block_size=64, # Tile size for attention computation + groups=1, # Grouped query attention (GQA) groups + selected_blocks=16, # Number of blocks to select per attention head +): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + # Modified shapes for inference (q has seq_len=1)a + q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor + ): + with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): + # Shared memory allocations for tile storage + Q_shared = T.alloc_shared([G, BK], dtype) # Current query block + K_shared = T.alloc_shared([BS, BK], dtype) # Current key block + V_shared = T.alloc_shared([BS, BV], dtype) # Current value block + O_shared = T.alloc_shared([G, BV], dtype) # Output accumulator + + # Attention computation buffers + acc_s = T.alloc_fragment([G, BS], accum_dtype) # QK^T scores + acc_s_cast = T.alloc_fragment([G, BS], dtype) # Casted scores for softmax + acc_o = T.alloc_fragment([G, BV], accum_dtype) # Output accumulator + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_v, i_bh = by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + # Copy Q for the single position + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # Main attention computation loop over selected blocks + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset + if i_s >= 0: # Skip invalid/padding blocks + # Load current key block to shared memory + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + # Compute QK^T attention scores + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Online softmax with numerical stability + # 1. 
Compute max for scaling + # 2. Compute exponentials and sum + # 3. Maintain running logsum for normalization + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Accumulate attention-weighted values + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Final normalization and output + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] # Normalize by logsum + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 + + return native_sparse_attention + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") + + out = kernel(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa_simple_inference( + q=Q, + k=K, + v=V, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..ad36b10402429cdf24b87016c084b2790a0ba0eb --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -0,0 +1,175 @@ +# ruff: noqa +import torch +from reference import naive_nsa +import tilelang +from tilelang import language as T +import tilelang.testing + +tilelang.testing.set_random_seed(0) + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, 
dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 2 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return native_sparse_attention + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + print(kernel.get_kernel_source()) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, 
SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + out = kernel(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + print("out", out) + print("ref", ref) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..b52ebe42e210823de24107b9990cc111d5c8f1b3 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -0,0 +1,380 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import tilelang +from tilelang import language as T +import tilelang.testing + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from reference import naive_nsa +from einops import rearrange + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [c_seq_len, heads, dim] + kv_shape = [c_seq_len, head_kv, dim] + o_slc_shape = [c_seq_len, heads, dim] + o_swa_shape = [c_seq_len, heads, dim] + lse_slc_shape = [c_seq_len, heads] + lse_swa_shape = [c_seq_len, heads] + block_indices_shape = [c_seq_len, head_kv, selected_blocks] + block_counts_shape = [c_seq_len, head_kv] + offsets_shape = [batch + 1] + token_indices_shape = [c_seq_len, 2] + block_indices_dtype = T.int32 + block_counts_dtype = T.int32 + offsets_dtype = T.int32 + token_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention_varlen( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, 
dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + ): + with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_c, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + i_n, i_t = TokenIndices[i_c, 0], TokenIndices[i_c, 1] + + bos = Offsets[i_n] + eos = Offsets[i_n + 1] + current_seq_len = eos - bos + + NS = BlockCounts[i_t, i_h] + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[bos + i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + # Lei: may have some padding issues + # we should learn from mha varlen templates to handle this + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return native_sparse_attention_varlen + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + + batch = len(offsets) - 1 + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + + kernel = native_sparse_attention_varlen( + batch=batch, + heads=HQ, + c_seq_len=C_SEQ_LEN, + dim=K, + is_causal=True, + block_size=block_size, + groups=G, + selected_blocks=S, + ) + + o_slc = 
torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
+    kernel(
+        q.view(C_SEQ_LEN, HQ, K),
+        k.view(C_SEQ_LEN, H, K),
+        v.view(C_SEQ_LEN, H, V),
+        o_slc.view(C_SEQ_LEN, HQ, V),
+        block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
+        block_counts.to(torch.int32).view(C_SEQ_LEN, H),
+        offsets.to(torch.int32),
+        token_indices.to(torch.int32),
+    )
+    return o_slc
+
+
+@torch.compile
+class ParallelNSAFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
+        ctx.dtype = q.dtype
+
+        # 2-d sequence indices denoting the offsets of tokens in each sequence
+        # for example, if the passed `offsets` is [0, 2, 6],
+        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
+        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+        token_indices = prepare_token_indices(offsets) if offsets is not None else None
+
+        o_slc = parallel_nsa_fwd(
+            q=q,
+            k=k,
+            v=v,
+            block_indices=block_indices,
+            block_counts=block_counts,
+            block_size=block_size,
+            window_size=window_size,
+            scale=scale,
+            offsets=offsets,
+            token_indices=token_indices,
+        )
+        return o_slc.to(q.dtype)
+
+
+def parallel_nsa(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g_slc: torch.Tensor,
+    g_swa: torch.Tensor,
+    block_indices: torch.LongTensor,
+    block_counts: Optional[Union[torch.LongTensor, int]] = None,
+    block_size: int = 64,
+    window_size: int = 0,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False,
+) -> torch.Tensor:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g_slc (torch.Tensor):
+            Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        g_swa (torch.Tensor):
+            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        block_indices (torch.LongTensor):
+            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
+            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
+        block_counts (Union[torch.LongTensor, int]):
+            Number of selected blocks for each token.
+            If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
+            each token can select a different number of blocks.
+            If not provided, it will default to `S`. Default: `None`.
+        block_size (int):
+            Selected block size. Default: 64.
+        window_size (int):
+            Sliding window size. Default: 0.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
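+
+    Example (illustrative sketch of variable-length usage, following the self-test at the bottom
+    of this file; the packed shapes and offsets below are assumptions):
+        Inputs are packed along the sequence dimension with batch size 1, and `cu_seqlens` marks
+        the boundaries of the individual sequences.
+        >>> cu_seqlens = torch.tensor([0, 16, 64], dtype=torch.long, device="cuda")  # lengths 16 and 48
+        >>> q = torch.randn(1, 64, 16, 64, dtype=torch.float16, device="cuda")
+        >>> k = torch.randn(1, 64, 1, 64, dtype=torch.float16, device="cuda")
+        >>> v = torch.randn(1, 64, 1, 64, dtype=torch.float16, device="cuda")
+        >>> g_slc = torch.ones(1, 64, 16, dtype=torch.float16, device="cuda")
+        >>> g_swa = torch.ones(1, 64, 16, dtype=torch.float16, device="cuda")
+        >>> block_indices = torch.zeros(1, 64, 1, 1, dtype=torch.long, device="cuda")
+        >>> block_counts = torch.ones(1, 64, 1, dtype=torch.long, device="cuda")
+        >>> o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_counts,
+        ...                  block_size=32, cu_seqlens=cu_seqlens)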
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + assert False, "Window size is not supported yet" + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 + torch.manual_seed(42) + # randomly split the sequence into N segments + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) + + # seq-first required for inputs with variable lengths + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + token_indices = prepare_token_indices(offsets).tolist() + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") + for i in range(C_SEQ_LEN): + _, t = token_indices[i] + for h in range(H): + i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] + block_indices[0, i, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/tilelang/original/examples/deepseek_nsa/example_triton_nsa_bwd.py b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_bwd.py new file mode 100644 index 
0000000000000000000000000000000000000000..af05bfa701654e3ec2dd53ffb2c0b50c61514801 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -0,0 +1,1008 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + # if USE_BLOCK_COUNTS: + # NS = tl.load(block_counts + (bos + i_t) * H + i_h) + # else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + 
@contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: 
tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BS, BK], dtype=tl.float32) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BS, BV], dtype=tl.float32) + + for i in range(i_s * BS, T): + b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) + if b_m_slc: + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1)) + # [G] + b_lse_slc = tl.load(p_lse_slc) + b_delta_slc = tl.load(p_delta_slc) + # [BS, G] + b_s_slc = tl.dot(b_k, tl.trans(b_q)) + b_p_slc = tl.exp(b_s_slc - b_lse_slc[None, :]) + b_p_slc = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p_slc, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p_slc.to(b_do_slc.dtype), b_do_slc) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp_slc = tl.dot(b_v, tl.trans(b_do_slc)) + # [BS, G] + b_ds_slc = b_p_slc * (b_dp_slc - b_delta_slc[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds_slc.to(b_q.dtype), b_q) + + if WS > 0: + o_s = i_s * BS + tl.arange(0, BS) + if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1)) + # [G] + b_lse_swa = tl.load(p_lse_swa) + b_delta_swa = tl.load(p_delta_swa) + # [BS, G] + b_s_swa = tl.dot(b_k, tl.trans(b_q)) + b_p_swa = tl.exp(b_s_swa - b_lse_swa[None, :]) + b_p_swa = tl.where((i >= o_s and (i - WS) < o_s)[:, None], b_p_swa, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p_swa.to(b_do_swa.dtype), b_do_swa) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp_swa = tl.dot(b_v, tl.trans(b_do_swa)) + # [BS, G] + b_ds_swa = b_p_swa * (b_dp_swa - b_delta_swa[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds_swa.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 
1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)}) +@triton.jit +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_s = i_hs // S, i_hs % S + + b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s) + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h) + else: + b_m = b_i * BS <= i_t + + if b_i < NS and b_i >= 0: + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (bos + i_t) * HQ * K + do_slc += (bos + i_t) * HQ * V + lse_slc += (bos + i_t) * HQ + delta_slc += (bos + i_t) * HQ + if WS > 0: + do_swa += (bos + i_t) * HQ * V + lse_swa += (bos + i_t) * HQ + delta_swa += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ * K + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_slc = tl.make_block_ptr(do_slc, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + i_h * G + tl.arange(0, G) + p_delta_slc = delta_slc + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1)) + # [G] + b_lse_slc = tl.load(p_lse_slc) + b_delta_slc = tl.load(p_delta_slc) + + # [G, BK] + b_dq_slc = tl.zeros([G, BK], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, 
boundary_check=(0, 1)) + # [BV, BS] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_p_slc = tl.exp(b_s_slc - b_lse_slc[:, None]) + b_p_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_slc, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp_slc = tl.dot(b_do_slc, b_v_slc) + b_ds_slc = b_p_slc * (b_dp_slc.to(tl.float32) - b_delta_slc[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq_slc += tl.dot(b_ds_slc.to(b_k_slc.dtype), tl.trans(b_k_slc)) + b_dq_slc *= scale + + if WS > 0: + p_do_swa = tl.make_block_ptr(do_swa, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + i_h * G + tl.arange(0, G) + p_delta_swa = delta_swa + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1)) + # [G] + b_lse_swa = tl.load(p_lse_swa) + b_delta_swa = tl.load(p_delta_swa) + + # [G, BK] + b_dq_swa = tl.zeros([G, BK], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BV, BS] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_p_swa = tl.exp(b_s_swa - b_lse_swa[:, None]) + b_p_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_swa, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp_swa = tl.dot(b_do_swa, b_v_swa) + b_ds_swa = b_p_swa * (b_dp_swa.to(tl.float32) - b_delta_swa[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq_swa += tl.dot(b_ds_swa.to(b_k_swa.dtype), tl.trans(b_k_swa)) + b_dq_swa *= scale + + if WS == 0: + tl.store(p_dq, b_dq_slc.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + else: + tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = 
tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + if WS > 0: + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_swa = tl.zeros([G, BV], dtype=tl.float32) + + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_swa = tl.zeros([G], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BS, BV] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) + + # [G] + b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa + b_r_swa = tl.exp(b_mp_swa - b_m_swa) + # [G, BS] + b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None]) + # [G] + b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1) + # [G, BV] + b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa) + + b_mp_swa = b_m_swa + b_o_swa = b_o_swa / b_acc_swa[:, None] + b_m_swa += tl.log(b_acc_swa) + tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty)) + + +@triton.jit +def parallel_nsa_bwd_kernel_preprocess(o, do, delta, B: tl.constexpr, V: tl.constexpr): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +def parallel_nsa_block_mask( + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + offsets: torch.LongTensor, + block_size: int, +): + B, T, H, S = block_indices.shape + BS = block_size + if offsets is not None: + NS = 
triton.cdiv(prepare_lens(offsets).max().item(), BS) + else: + NS = triton.cdiv(T, BS) + block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) + + parallel_nsa_kernel_mask[(T, B, H * S)]( + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) + return block_mask + + +def parallel_nsa_bwd_preprocess(o: torch.Tensor, do: torch.Tensor): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + lse_slc: torch.Tensor, + do_slc: torch.Tensor, + o_swa: torch.Tensor, + lse_swa: torch.Tensor, + do_swa: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + window_size: int = 0, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + delta_slc = parallel_nsa_bwd_preprocess(o_slc, do_slc) + delta_swa = parallel_nsa_bwd_preprocess(o_swa, do_swa) if window_size > 0 else None + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse_slc=lse_slc, + delta_slc=delta_slc, + do_slc=do_slc, + lse_swa=lse_swa, + delta_swa=delta_swa, + do_swa=do_swa, + dq=dq, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + dq = dq.sum(0) + + if offsets is not None: + chunk_indices = prepare_chunk_indices(offsets, BS) + NS = len(chunk_indices) + else: + chunk_indices = None + NS = triton.cdiv(T, BS) + + # [B, T, H, M] + block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size) + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NS, B * H) + parallel_nsa_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse_slc=lse_slc, + lse_swa=lse_swa, + delta_slc=delta_slc, + delta_swa=delta_swa, + do_slc=do_slc, + do_swa=do_swa, + dk=dk, + dv=dv, + block_mask=block_mask, + offsets=offsets, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + M=block_mask.shape[-1], + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not 
None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + print("tri", tri) + print("ref", ref) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + # assert_close(" o", ref, tri, 0.004) + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2) diff --git a/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd.py b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ab28daaf931ebc7565130343f8e4c15a1570d2 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -0,0 +1,357 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version 
import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + # if USE_BLOCK_COUNTS: + # NS = tl.load(block_counts + (bos + i_t) * H + i_h) + # else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 
1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: 
Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor of shape `[B, H, T]` (`head_first=True`) or `[B, T, H]` (`head_first=False`) is provided, + each token can select a different number of blocks. + If not provided, it defaults to `S`. Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4eb6d7ba6119a0ebf16700d65b55b1fd1a237b --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -0,0 +1,392 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) 
+@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + if WS > 0: + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_swa = tl.zeros([G, BV], dtype=tl.float32) + + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_swa = tl.zeros([G], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BS, BV] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_s_swa = tl.where((i_t >= (i_s + 
tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) + + # [G] + b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa + b_r_swa = tl.exp(b_mp_swa - b_m_swa) + # [G, BS] + b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None]) + # [G] + b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1) + # [G, BV] + b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa) + + b_mp_swa = b_m_swa + b_o_swa = b_o_swa / b_acc_swa[:, None] + b_m_swa += tl.log(b_acc_swa) + tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty)) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: 
torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor of shape `[B, H, T]` (`head_first=True`) or `[B, T, H]` (`head_first=False`) is provided, + each token can select a different number of blocks. + If not provided, it defaults to `S`. Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 + torch.manual_seed(42) + # randomly split the sequence into N segments + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) + # offsets.shape is [N+1] + # seq-first required for inputs with variable lengths + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") + + token_indices = prepare_token_indices(offsets).tolist() + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") + for i in range(T): + _, t = token_indices[i] + for h in range(H): + i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] + block_indices[0, i, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/tilelang/original/examples/deepseek_nsa/reference.py b/tilelang/original/examples/deepseek_nsa/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..58083108eb30e871fba15b60a9f36bacee9c3949 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/reference.py @@ -0,0 +1,305 @@ +# ruff: noqa +from typing import Optional + +import torch +from typing import Union +from 
einops import rearrange, repeat + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor of shape `[B, H, T]` (`head_first=True`) or `[B, T, H]` (`head_first=False`) is provided, + each token can select a different number of blocks. + If not provided, it defaults to `S`. Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + S = block_indices.shape[-1] + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + if isinstance(block_counts, torch.Tensor): + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o_slc = torch.zeros_like(v) + o_swa = torch.zeros_like(v) if window_size > 0 else None + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[i] + else: + s_b = block_counts + else: + T = cu_seqlens[i + 1] - cu_seqlens[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] + else: + s_b = block_counts + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [HQ] + g_slc_i = g_slc_b[i_q] + # [HQ] + g_swa_i = g_swa_b[i_q] + # [S*BS, HQ] + i_i = i_b[i_q] + # [HQ] + if isinstance(block_counts, torch.Tensor): + s_i = s_b[i_q] + else: + s_i = s_b + # [S*BS, HQ, -1] + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) + if not varlen: + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + else: + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + if window_size > 0: + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) + if not varlen: + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + else: + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + + if head_first: + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") + + return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) + + +def naive_nsa_simple( + q: torch.Tensor, + k: torch.Tensor, + 
v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: torch.LongTensor, + block_size: int = 64, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (torch.LongTensor): + Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + block_size (int): + Selected block size. Default: 64. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + scale = k.shape[-1] ** -0.5 + + dtype = q.dtype + HQ = q.shape[2] + H = k.shape[2] + D = k.shape[-1] + G = HQ // H + BS = block_size + S = block_indices.shape[-1] + SELECTED_BLOCKS_SIZE = S * BS + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + o = torch.zeros_like(v) + B, T = q.shape[:2] + + for i in range(B): + q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i] + # [T, HQ, S, BS] -> [T, HQ, S*BS] + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, HQ, S*BS] -> [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [S*BS, HQ] -> represents selected blocks for each query token + i_i = i_b[i_q] + # [HQ] -> represents the number of selected blocks for each query token + s_i = s_b[i_q] + + k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype) + v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype) + + for h in range(HQ): + for t in range(SELECTED_BLOCKS_SIZE): + selected_block_index = i_i[t, h] + k_i[t, h] = k_b[selected_block_index, h, :] + v_i[t, h] = v_b[selected_block_index, h, :] + + # [S*BS, HQ] + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) + attn = torch.softmax(attn, dim=0) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) + + return o.to(dtype) + + +def naive_nsa_simple_inference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: torch.LongTensor, + block_size: int = 64, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`. 
+ `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (torch.LongTensor): + Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`. + block_size (int): + Selected block size. Default: 64. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + scale = k.shape[-1] ** -0.5 + + dtype = q.dtype + HQ = q.shape[2] + H = k.shape[2] + D = k.shape[-1] + G = HQ // H + BS = block_size + S = block_indices.shape[-1] + SELECTED_BLOCKS_SIZE = S * BS + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + o = torch.zeros_like(q) + B, T = q.shape[:2] + + for i in range(B): + q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i] + # [T, HQ, S, BS] -> [T, HQ, S*BS] + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, HQ, S*BS] -> [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + + # [HQ, D] + q_i = q_b[0] * scale + # [S*BS, HQ] -> represents selected blocks for each query token + i_i = i_b[0] + # [HQ] -> represents the number of selected blocks for each query token + s_i = s_b[0] + + k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype) + v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype) + + for h in range(HQ): + for t in range(SELECTED_BLOCKS_SIZE): + selected_block_index = i_i[t, h] + k_i[t, h] = k_b[selected_block_index, h, :] + v_i[t, h] = v_b[selected_block_index, h, :] + + # [S*BS, HQ] + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) + attn = torch.softmax(attn, dim=0) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) + + return o.to(dtype) diff --git a/tilelang/original/examples/deepseek_nsa/requirements.txt b/tilelang/original/examples/deepseek_nsa/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..777c2ad4c81bbf9c00a4fca8361c7dd9dfb39d0e --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e \ No newline at end of file diff --git a/tilelang/original/examples/deepseek_nsa/test_example_tilelang_nsa.py b/tilelang/original/examples/deepseek_nsa/test_example_tilelang_nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc6f98e962c167bf3f2783910c7dda9bd624373 --- /dev/null +++ b/tilelang/original/examples/deepseek_nsa/test_example_tilelang_nsa.py @@ -0,0 +1,17 @@ +# ruff: noqa +import tilelang.testing + +from example_tilelang_nsa_fwd import main as main_fwd +from example_tilelang_nsa_decode import main as main_fwd_decode + + +def test_example_tilelang_nsa_fwd(): + main_fwd() + + +def test_example_tilelang_nsa_fwd_decode(): + main_fwd_decode() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/deepseek_v32/README.md b/tilelang/original/examples/deepseek_v32/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8457745b0e9ea4aea5e5f06c5417859cb51a41bc --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/README.md @@ -0,0 +1,223 @@ +## Directory Structure + +``` +deepseek_v32/ +├── 
README.md # This file +├── figures/ # Figures and diagrams +├── inference/ # Inference implementation folder +├── fp8_lighting_indexer.py # FP8 lighting indexer +├── sparse_mla_bwd.py # Sparse MLA backward implementation +├── sparse_mla_fwd.py # Sparse MLA forward implementation +├── sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass +├── topk_selector.py # Top-k selector implementation +``` + +## File Descriptions + +### Architecture Overview + +![DeepSeek V3.2 Architecture](./figures/v32_arch.png) + +The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations: + +1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision +2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation +3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward and backward passes + +### Lightning Indexer + +Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector. + +The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices: + +```python +T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, +) +``` + +After the matmul, we apply ReLU and aggregate across heads with learned weights: + +```python +for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = ( + T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i] + ) * index_k_scale_fragment[bn_i] + +T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) +``` + +The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions: + +```python +for bq_i in T.serial(block_Q): + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) +for bq_i in T.serial(block_Q): + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) +``` + +The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training. + +### Top-k Selector + +The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process. + +The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence: + +```python +for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s*BLOCK_SIZE+tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) +``` + +The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. 
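To make that mapping concrete, here is a small PyTorch sketch of one standard order-preserving float-to-uint16 transform (a reference illustration with our own naming, not the kernel's `convert_to_uint16`): negative floats have all bits flipped, non-negative floats get the sign bit set, and the high 16 bits are kept as the radix bucket.

```python
import torch

def convert_to_uint16_reference(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret float32 bits, then widen to int64 so the 32-bit masks are easy to express.
    bits = x.float().view(torch.int32).to(torch.int64) & 0xFFFFFFFF
    # Negative floats (sign bit set): flip all bits; non-negative floats: set the sign bit.
    keys = torch.where(bits >= 0x80000000, (~bits) & 0xFFFFFFFF, bits | 0x80000000)
    # Keep the high 16 bits as the bucket id (0..65535); larger floats land in larger buckets.
    return keys >> 16
```

Under such a transform, larger logits fall into higher histogram bins, which is what the threshold search below relies on.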
After building a histogram and doing a cumulative sum, we find the threshold bin: + +```python +if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx +``` + +Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing: + +```python +if l_bin_id32 > l_threshold_bin_id: + pos = T.atomic_add(s_histogram[l_bin_id32+1], 1, return_prev=True) + index[bx, pos] = input_idx +elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + pos = T.atomic_add(s_num_input[0], 1, return_prev=True) + s_input_idx[0, pos] = input_idx +``` + +Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence. + +### Sparse MLA Forward + +The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens. + +Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions: + +```python +# Dense MLA: iterate over full sequence +loop_range = T.ceildiv(seqlen_kv, block_N) +for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) + # ... compute attention over this block +``` + +Sparse MLA only loads KV positions selected by the top-k selector: + +```python +# Sparse MLA: iterate over selected indices only +for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + # ... compute attention over selected tokens +``` + +This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid: + +```python +for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i +``` + +Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA. + +### Sparse MLA Forward (Pipelined) + +The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines. + +The key difference is splitting the warp groups into specialized roles: + +```python +if tx < 128: + # Consumer 0: computes left half of output (D//2 dimensions) + # Handles QK matmul, softmax, and PV for left half + +elif tx >= 128 and tx < 256: + # Consumer 1: computes right half of output (D//2 dimensions) + # Only does PV matmul for right half + +elif tx >= 256: + # Producer: loads KV data from global memory + # Uses async copy with barriers to feed consumers +``` + +The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed: + +```python +# Producer alternates between two buffers +for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + # ... 
load KV into buffer 0 + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + # ... load KV into buffer 1 + T.cp_async_barrier_noinc(bar_k_1_ready[0]) +``` + +Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul. + +### Sparse MLA Backward + +The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity. + +The backward pass consists of three main stages: + +**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient): + +```python +for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] +T.reduce_sum(acc, delta, 1) +``` + +**2. Main Backward Computation**: Computes gradients through sparse attention: + +```python +# Sparse MLA backward: iterate over selected indices only +for i_i in T.Pipelined(NI, num_stages=num_stages): + # Load KV data for selected indices + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i] + + # Recompute attention scores for backward + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + # Apply softmax gradient: dP = P * (dP_raw - Delta) + for h_i, bi_i in T.Parallel(padded_H, BI): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale +``` + +The key gradient computations are: +- **dQ = dP @ K** (query gradients) +- **dK = dP^T @ Q** (key gradients) +- **dV = P^T @ dO** (value gradients) + +**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation: + +```python +# Atomically update dKV at selected indices +for bi_i, d_i in T.Parallel(BI // split_store, D // 4): + T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4]) +``` + +**Performance**: The sparse MLA backward achieves excellent performance: +- **H800 SXM**: ~100 TFlops +- **H200 SXM**: ~115 TFlops + +The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization. 
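To see why the `dP = P * (dP_raw - Delta)` step above is correct, the short PyTorch sketch below (our own illustration, not the kernel code) checks the softmax-gradient identity with `delta = rowsum(O * dO)` against autograd for a single query row:

```python
import torch

torch.manual_seed(0)
n, d = 8, 4                                                    # selected KV tokens, value dim
s = torch.randn(n, dtype=torch.float64, requires_grad=True)    # attention scores for one query row
v = torch.randn(n, d, dtype=torch.float64)
do = torch.randn(d, dtype=torch.float64)                       # upstream gradient of the output row

p = torch.softmax(s, dim=0)     # attention probabilities P
o = p @ v                       # output row O
o.backward(do)                  # autograd reference for dS

dp = v @ do                     # raw probability gradient dP_raw = dO @ V^T
delta = (o * do).sum()          # the Delta term produced by the preprocessing kernel
ds = p * (dp - delta)           # softmax gradient identity used in the main backward kernel

print(torch.allclose(s.grad, ds))  # True
```

The kernel applies the same identity per head and per selected block, with `sm_scale` folded in because the scores were scaled before the softmax.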
diff --git a/tilelang/original/examples/deepseek_v32/figures/v32_arch.png b/tilelang/original/examples/deepseek_v32/figures/v32_arch.png new file mode 100644 index 0000000000000000000000000000000000000000..50f3a847b509868c7b04af20e1edb81b54bc6bb6 Binary files /dev/null and b/tilelang/original/examples/deepseek_v32/figures/v32_arch.png differ diff --git a/tilelang/original/examples/deepseek_v32/fp8_lighting_indexer.py b/tilelang/original/examples/deepseek_v32/fp8_lighting_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..01ad0a73469b7b7feff0a58f7918d3d144ec3c19 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/fp8_lighting_indexer.py @@ -0,0 +1,284 @@ +# ruff: noqa +import itertools +import tilelang +from tilelang import language as T +import torch +from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8 + + +def display_error_message(msg): + print(f"\033[31mWARNING: {msg}\033[0m") + + +def compute_correlation(a, b, label="tensor"): + a, b = a.data.double(), b.data.double() + norm_sum = (a * a + b * b).sum() + if norm_sum == 0: + display_error_message(f"{label} all zero") + return 1 + correlation = 2 * (a * b).sum() / norm_sum + return correlation + + +def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True): + a_finite = torch.isfinite(a) + b_finite = torch.isfinite(b) + if not torch.all(a_finite == b_finite): + display_error_message(f"{tensor_name} Error: isfinite mask mismatch") + if should_raise: + assert False + if not torch.isclose( + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, + ).all(): + display_error_message(f"{tensor_name} Error: nonfinite value mismatch") + if should_raise: + assert False + a = a.masked_fill(~a_finite, 0) + b = b.masked_fill(~b_finite, 0) + correlation = compute_correlation(a, b, tensor_name) + difference = 1.0 - correlation + if not (0 <= difference <= tolerance): + display_error_message(f"{tensor_name} Error: {difference}") + if should_raise: + assert False + return difference + + +def get_configs(): + iter_params = dict( + block_N=[32, 64, 128], + num_stages=[0, 1, 2], + threads=[128, 256], + block_Q=[1, 2, 4], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +class SupplyProg: + def __init__(self): + self.tensors_dict = {} + + def get_key(self, shape, dtype) -> str: + return f"{shape}-{dtype}" + + def supply_prog(self, params): + shapes = [p.shape for p in params] + dtypes = [p.dtype for p in params] + tensor_list = [] + for shape, dtype in zip(shapes, dtypes): + key = self.get_key(shape, dtype) + if key not in self.tensors_dict: + self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda") + tensor_list.append(self.tensors_dict[key]) + else: + tensor_list.append(self.tensors_dict[key]) + return tensor_list + + +supply_prog = SupplyProg() + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def mqa_attn_return_logits( + heads, + index_dim, + block_N=256, + num_stages=3, + threads=512, + block_Q=None, +): + if block_Q is None: + block_Q = 128 // heads + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + index_q_shape = [seq_len * heads, index_dim] + index_k_shape = [seq_len_kv, index_dim] + index_k_scale_shape = [seq_len_kv] + logits_shape = [seq_len, seq_len_kv] + + @T.prim_func + def 
mqa_attn_return_logits_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], dtype) + index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_local([1], index_dtype) + cu_k_e_max = T.alloc_local([1], index_dtype) + + cu_k_s_min[0] = 2147483647 + cu_k_e_max[0] = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) + for bq_i in T.serial(block_Q): + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): + T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] + + return mqa_attn_return_logits_kernel + + +@tilelang.jit +def clean_logits_( + threads: int = 512, + block_K: int = 4096, +): + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + dtype = T.float + indices_dtype = T.int32 + + @T.prim_func + def clean_logits_kernel( + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + ): + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = T.alloc_local([1], indices_dtype) + cu_k_e = T.alloc_local([1], indices_dtype) + cu_k_s[0] = CuSeqLenKS[bx] + cu_k_e[0] = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx < cu_k_s[0] or idx >= cu_k_e[0]: + Logits[bx, idx] = -T.infinity(dtype) + + return clean_logits_kernel + + +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): + seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[0] + + clean_logits_kernel = clean_logits_() + + mqa_attn_return_logits_kernel 
= mqa_attn_return_logits(heads=heads, index_dim=index_dim) + logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) + mqa_attn_return_logits_kernel( + q.view(seq_len * heads, index_dim), + kv, + kv_scales, + logits, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + if clean_logits: + clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + + +def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + # initial random seed to make the performance reproducible + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + + print(f"diff: {diff}") + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + logits_fn() + + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) + + logits_ms = do_bench(logits_fn, warmup=100, rep=100) + logits_flops = 2 * cost_ref * H * D + logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12 + print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") + print(f"cost_ref: {cost_ref}") + + +if __name__ == "__main__": + test_fp8_lighting_indexer() diff --git a/tilelang/original/examples/deepseek_v32/inference/README.md b/tilelang/original/examples/deepseek_v32/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fe4cc21bba684273345706317f49012f4bb96d71 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/README.md @@ -0,0 +1,14 @@ +# DeepSeek V3.2 + +First convert huggingface model weights to the the format required by our inference demo. 
Set `MP` to match your available GPU count: +```bash +cd inference +export EXPERTS=256 +python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP} +``` + +Launch the interactive chat interface and start exploring DeepSeek's capabilities: +```bash +export CONFIG=config_671B_v3.2.json +torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive +``` \ No newline at end of file diff --git a/tilelang/original/examples/deepseek_v32/inference/config_671B_v3.2.json b/tilelang/original/examples/deepseek_v32/inference/config_671B_v3.2.json new file mode 100644 index 0000000000000000000000000000000000000000..be88f1cca20c7dc78d8459c4c8456c197cba0b5a --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/config_671B_v3.2.json @@ -0,0 +1,26 @@ +{ + "vocab_size": 129280, + "dim": 7168, + "inter_dim": 18432, + "moe_inter_dim": 2048, + "n_layers": 61, + "n_dense_layers": 3, + "n_heads": 128, + "n_routed_experts": 256, + "n_shared_experts": 1, + "n_activated_experts": 8, + "n_expert_groups": 8, + "n_limited_groups": 4, + "route_scale": 2.5, + "score_func": "sigmoid", + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "dtype": "fp8", + "scale_fmt": "ue8m0", + "index_n_heads": 64, + "index_head_dim": 128, + "index_topk": 2048 +} \ No newline at end of file diff --git a/tilelang/original/examples/deepseek_v32/inference/convert.py b/tilelang/original/examples/deepseek_v32/inference/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..df7943918f80557af7a0485b7d3591d070ffcbab --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/convert.py @@ -0,0 +1,100 @@ +import os +import shutil +from argparse import ArgumentParser +from glob import glob +from tqdm import tqdm, trange + +import torch +from safetensors.torch import safe_open, save_file + +mapping = { + "embed_tokens": ("embed", 0), + "input_layernorm": ("attn_norm", None), + "post_attention_layernorm": ("ffn_norm", None), + "q_proj": ("wq", 0), + "q_a_proj": ("wq_a", None), + "q_a_layernorm": ("q_norm", None), + "q_b_proj": ("wq_b", 0), + "kv_a_proj_with_mqa": ("wkv_a", None), + "kv_a_layernorm": ("kv_norm", None), + "kv_b_proj": ("wkv_b", 0), + "o_proj": ("wo", 1), + "gate": ("gate", None), + "gate_proj": ("w1", 0), + "down_proj": ("w2", 1), + "up_proj": ("w3", 0), + "norm": ("norm", None), + "lm_head": ("head", 0), + "scale": ("scale", None), + "wq_b": ("wq_b", None), + "wk": ("wk", None), + "k_norm": ("k_norm", None), + "weights_proj": ("weights_proj", None), +} + + +def main(hf_ckpt_path, save_path, n_experts, mp): + """ + Converts and saves model checkpoint files into a specified format. + + Args: + hf_ckpt_path (str): Path to the directory containing the input checkpoint files. + save_path (str): Path to the directory where the converted checkpoint files will be saved. + n_experts (int): Total number of experts in the model. + mp (int): Model parallelism factor. 
+ + Returns: + None + """ + torch.set_num_threads(8) + n_local_experts = n_experts // mp + state_dicts = [{} for _ in range(mp)] + + for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + with safe_open(file_path, framework="pt", device="cpu") as f: + for name in f.keys(): + if "model.layers.61" in name: + continue + param: torch.Tensor = f.get_tensor(name) + if name.startswith("model."): + name = name[len("model."):] + name = name.replace("self_attn", "attn") + name = name.replace("mlp", "ffn") + name = name.replace("weight_scale_inv", "scale") + name = name.replace("e_score_correction_bias", "bias") + key = name.split(".")[-2] + assert key in mapping, f"Key {key} not found in mapping" + new_key, dim = mapping[key] + name = name.replace(key, new_key) + for i in range(mp): + new_param = param + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + continue + elif dim is not None: + assert param.size( + dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" + shard_size = param.size(dim) // mp + new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + state_dicts[i][name] = new_param + + os.makedirs(save_path, exist_ok=True) + + for i in trange(mp): + save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) + + for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): + new_file_path = os.path.join(save_path, os.path.basename(file_path)) + shutil.copyfile(file_path, new_file_path) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--hf-ckpt-path", type=str, required=True) + parser.add_argument("--save-path", type=str, required=True) + parser.add_argument("--n-experts", type=int, required=True) + parser.add_argument("--model-parallel", type=int, required=True) + args = parser.parse_args() + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" + main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) diff --git a/tilelang/original/examples/deepseek_v32/inference/generate.py b/tilelang/original/examples/deepseek_v32/inference/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..fda1e80968dc610574f576886e2b33c6ee24e56f --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/generate.py @@ -0,0 +1,197 @@ +import os +import json +from argparse import ArgumentParser +from typing import List + +import torch +import torch.distributed as dist +from transformers import AutoTokenizer +from safetensors.torch import load_model + +from model import Transformer, ModelArgs + + +def sample(logits, temperature: float = 1.0): + """ + Samples a token from the logits using temperature scaling. + + Args: + logits (torch.Tensor): The logits tensor for token predictions. + temperature (float, optional): Temperature for scaling logits. Defaults to 1.0. + + Returns: + torch.Tensor: The sampled token. + """ + logits = logits / max(temperature, 1e-5) + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) + + +@torch.inference_mode() +def generate(model: Transformer, + prompt_tokens: List[List[int]], + max_new_tokens: int, + eos_id: int, + temperature: float = 1.0) -> List[List[int]]: + """ + Generates new tokens based on the given prompt tokens using the specified model. 
+ + Args: + model (Transformer): The transformer model used for token generation. + prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence. + max_new_tokens (int): The maximum number of new tokens to generate. + eos_id (int): The end-of-sequence token ID. + temperature (float, optional): The temperature value for sampling. Defaults to 1.0. + + Returns: + List[List[int]]: A list of lists containing the generated tokens for each sequence. + """ + prompt_lens = [len(t) for t in prompt_tokens] + assert max( + prompt_lens + ) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" + total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) + tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") + for i, t in enumerate(prompt_tokens): + tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + prev_pos = 0 + finished = torch.tensor([False] * len(prompt_tokens), device="cuda") + prompt_mask = tokens != -1 + for cur_pos in range(min(prompt_lens), total_len): + logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + next_token = sample(logits, temperature) + else: + next_token = logits.argmax(dim=-1) + next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) + prev_pos = cur_pos + if finished.all(): + break + completion_tokens = [] + for i, toks in enumerate(tokens.tolist()): + toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens] + if eos_id in toks: + toks = toks[:toks.index(eos_id)] + completion_tokens.append(toks) + return completion_tokens + + +def main( + ckpt_path: str, + config: str, + input_file: str = "", + interactive: bool = True, + max_new_tokens: int = 100, + temperature: float = 1.0, +) -> None: + """ + Main function to load the model and perform interactive or batch text generation. + + Args: + ckpt_path (str): Path to the model checkpoint directory. + config (str): Path to the model configuration file. + input_file (str, optional): Path to a file containing input prompts. Defaults to "". + interactive (bool, optional): Whether to run in interactive mode. Defaults to True. + max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100. + temperature (float, optional): Temperature for sampling. Defaults to 1.0. 
+ """ + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + if world_size > 1: + dist.init_process_group("nccl") + global print + if rank != 0: + print = lambda *_, **__: None + torch.cuda.set_device(local_rank) + torch.set_default_dtype(torch.bfloat16) + torch.set_num_threads(8) + torch.manual_seed(33377335) + with open(config) as f: + args = ModelArgs(**json.load(f)) + print(args) + with torch.device("cuda"): + model = Transformer(args) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + print("load model") + load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) + print("I'm DeepSeek 👋") + + if interactive: + messages = [] + while True: + if world_size == 1: + prompt = input(">>> ") + elif rank == 0: + prompt = input(">>> ") + objects = [prompt] + dist.broadcast_object_list(objects, 0) + else: + objects = [None] + dist.broadcast_object_list(objects, 0) + prompt = objects[0] + if prompt == "/exit": + break + elif prompt == "/clear": + messages.clear() + continue + messages.append({"role": "user", "content": prompt}) + prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + completion_tokens = generate(model, [prompt_tokens], max_new_tokens, + tokenizer.eos_token_id, temperature) + completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) + print(completion) + messages.append({"role": "assistant", "content": completion}) + else: + with open(input_file) as f: + prompts = f.read().split("\n\n") + assert len( + prompts + ) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" + prompt_tokens = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True) for prompt in prompts + ] + completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, + temperature) + completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + for prompt, completion in zip(prompts, completions): + print("Prompt:", prompt) + print("Completion:", completion) + print() + + if world_size > 1: + dist.destroy_process_group() + + +if __name__ == "__main__": + """ + Command-line interface for distributed text generation. + + Arguments: + --ckpt-path (str): Path to the model checkpoint directory. + --config (str): Path to the model configuration file. + --input-file (str, optional): File containing prompts for batch processing. + --interactive (bool, optional): Enable interactive mode for generating text. + --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. + --temperature (float, optional): Temperature for sampling. Defaults to 0.2. + + Raises: + AssertionError: If neither input-file nor interactive mode is specified. 
+ """ + parser = ArgumentParser() + parser.add_argument("--ckpt-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--input-file", type=str, default="") + parser.add_argument("--interactive", action="store_true") + parser.add_argument("--max-new-tokens", type=int, default=200) + parser.add_argument("--temperature", type=float, default=0.6) + args = parser.parse_args() + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" + main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, + args.temperature) diff --git a/tilelang/original/examples/deepseek_v32/inference/kernel.py b/tilelang/original/examples/deepseek_v32/inference/kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..25abf15d597caea97d0a890ca09a9cf73c7aa084 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/kernel.py @@ -0,0 +1,268 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple, Optional + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, +} + +FP8 = T.float8_e4m3fn +BF16 = T.bfloat16 +FP32 = T.float32 + + +def fast_log2_ceil(x): + bits_x = T.reinterpret(T.uint32, x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret(T.float32, bits_x) + + +def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False): + M = T.dynamic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): + with T.Kernel( + T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m,), scale_dtype) + s_local = T.alloc_fragment((blk_m,), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant(x: torch.Tensor, + block_size: int = 128, + scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the 
input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})") + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(pass_configs=pass_configs) +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32): + assert out_dtype in [BF16, T.float32] + + M = T.dynamic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32], + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, + b_s: torch.Tensor) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. 
+ """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), ( + "Scaling factor tensors must be contiguous") + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int): + b = T.dynamic("b") + m = T.dynamic("m") + n = T.dynamic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. + k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. + + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) diff --git a/tilelang/original/examples/deepseek_v32/inference/model.py b/tilelang/original/examples/deepseek_v32/inference/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e7468f0587e03de4d36d637fa2e149624c904b --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/model.py @@ -0,0 +1,972 @@ +import math +from dataclasses import dataclass +from typing import Tuple, Optional, Literal + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from kernel import act_quant, fp8_gemm, fp8_index + +world_size = 1 +rank = 0 +block_size = 128 + + +@dataclass +class ModelArgs: + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + scale_fmt (Optional[str]): Format for quantization scale. 
+ vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + index_head_dim (int): Dimension for index head. + index_topk (int): Top-k for index head. + """ + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + scale_fmt: Optional[str] = None + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + # moe + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1. + # mla + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1. + # index + index_n_heads: int = 64 + index_head_dim: int = 128 + index_topk: int = 2048 + + +class ParallelEmbedding(nn.Module): + """ + Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ + + def __init__(self, vocab_size: int, dim: int): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" + self.part_vocab_size = (vocab_size // world_size) + self.vocab_start_idx = rank * self.part_vocab_size + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for parallel embedding layer. + + Args: + x (torch.Tensor): Input tensor containing token indices. + + Returns: + torch.Tensor: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. 
+ """ + if world_size > 1: + mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + x = x - self.vocab_start_idx + x[mask] = 0 + y = F.embedding(x, self.weight) + if world_size > 1: + y[mask] = 0 + dist.all_reduce(y) + return y + + +def linear(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + scale_fmt: Optional[str] = None) -> torch.Tensor: + """ + Applies a linear transformation to the incoming data: y = xA^T + b. + This function supports specialized implementations based on quantization + and tensor formats. + + Args: + x (torch.Tensor): The input tensor. + weight (torch.Tensor): The weight tensor. It may be quantized and + requires dequantization for certain cases. + bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. + scale_fmt (Optional[str]): The format of scaling factors. + + Returns: + torch.Tensor: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version + is used for computation. + - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. + """ + assert bias is None + + if weight.dtype != torch.float8_e4m3fn: + return F.linear(x, weight) + else: + x, scale = act_quant(x, block_size, scale_fmt) + return fp8_gemm(x, scale, weight, weight.scale) + + +class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + dtype = torch.bfloat16 + scale_fmt: Optional[str] = None + + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + else: + self.register_parameter("scale", None) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ + return linear(x, self.weight, self.bias, self.scale_fmt) + + +class ColumnParallelLinear(Linear): + """ + Linear layer with column parallelism, splitting output features across distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. 
+ """ + + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None): + assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" + self.part_out_features = out_features // world_size + super().__init__(in_features, self.part_out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for column parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with column-parallel computation. + """ + y = linear(x, self.weight, self.bias, self.scale_fmt) + return y + + +class RowParallelLinear(Linear): + """ + Linear layer with row parallelism, splitting input features across distributed processes. + + Args: + in_features (int): Total number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = False, + reduce_output=True, + dtype=None): + assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" + self.part_in_features = in_features // world_size + self.reduce_output = reduce_output + super().__init__(self.part_in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for row parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with row-parallel computation. + """ + y = linear(x, self.weight, None, self.scale_fmt) + if self.reduce_output and world_size > 1: + y = y.float() + dist.all_reduce(y) + if self.bias is not None: + y += self.bias + return y.type_as(x) + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None): + """ + Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ + dtype = x.dtype + if residual is None: + x = x.float() + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype) + else: + x = residual = x.float() + residual.float() + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype), residual.to(dtype) + + +class LayerNorm(nn.Module): + """ + Layer Normalization. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x) + + +def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. 
+ + Args: + args (ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if seqlen > args.original_seq_len: + low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. 
+ """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) + + +class Indexer(torch.nn.Module): + + def __init__(self, args: ModelArgs): + super().__init__() + self.dim: int = args.dim + self.n_heads: int = args.index_n_heads + self.n_local_heads = args.index_n_heads // world_size + self.head_dim: int = args.index_head_dim + self.rope_head_dim: int = args.qk_rope_head_dim + self.index_topk: int = args.index_topk + self.q_lora_rank: int = args.q_lora_rank + self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim) + self.wk = Linear(self.dim, self.head_dim) + self.k_norm = LayerNorm(self.head_dim) + self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype()) + self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = args.scale_fmt + + self.register_buffer( + "k_cache", + torch.zeros( + args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), + persistent=False) + self.register_buffer( + "k_scale_cache", + torch.zeros( + args.max_batch_size, + args.max_seq_len, + self.head_dim // block_size, + dtype=torch.float32), + persistent=False) + + def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor]): + bsz, seqlen, _ = x.size() + end_pos = start_pos + seqlen + q = self.wq_b(qr) + q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_pe, q_nope], dim=-1) + k = self.wk(x) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2) + k = torch.cat([k_pe, k_nope], dim=-1) + q = rotate_activation(q) + k = rotate_activation(k) + q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt) + self.k_cache[:bsz, start_pos:end_pos] = k_fp8 + self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale + weights = self.weights_proj(x) * self.n_heads**-0.5 + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + index_score = fp8_index(q_fp8.contiguous(), weights, + self.k_cache[:bsz, :end_pos].contiguous(), + self.k_scale_cache[:bsz, :end_pos].contiguous()) + if mask is not None: + index_score += mask + topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1] + topk_indices_ = topk_indices.clone() + dist.broadcast(topk_indices_, src=0) + assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}" + return topk_indices + + +def weight_dequant(weight, scale): + shape = weight.shape + assert weight.dim() == 2 + weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, + block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size) + weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view( + shape[0] // block_size, shape[1] // block_size, block_size, + block_size).transpose(1, 2).contiguous().view(shape) + return weight + + +class 
MLA(nn.Module): + """ + Multi-Head Latent Attention (MLA) Layer. + + Attributes: + dim (int): Dimensionality of the input features. + n_heads (int): Number of attention heads. + n_local_heads (int): Number of local attention heads for distributed systems. + q_lora_rank (int): Rank for low-rank query projection. + kv_lora_rank (int): Rank for low-rank key/value projection. + qk_nope_head_dim (int): Dimensionality of non-positional query/key projections. + qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections. + qk_head_dim (int): Total dimensionality of query/key projections. + v_head_dim (int): Dimensionality of value projections. + softmax_scale (float): Scaling factor for softmax in attention computation. + """ + + def __init__(self, args: ModelArgs): + super().__init__() + self.dim = args.dim + self.n_heads = args.n_heads + self.n_local_heads = args.n_heads // world_size + self.q_lora_rank = args.q_lora_rank + self.kv_lora_rank = args.kv_lora_rank + self.qk_nope_head_dim = args.qk_nope_head_dim + self.qk_rope_head_dim = args.qk_rope_head_dim + self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim + self.v_head_dim = args.v_head_dim + + self.wq_a = Linear(self.dim, self.q_lora_rank) + self.q_norm = RMSNorm(self.q_lora_rank) + self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim) + self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.kv_norm = RMSNorm(self.kv_lora_rank) + self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) + self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) + self.softmax_scale = self.qk_head_dim**-0.5 + if args.max_seq_len > args.original_seq_len: + mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.indexer = Indexer(args) + + self.register_buffer( + "kv_cache", + torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), + persistent=False) + self.register_buffer( + "pe_cache", + torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), + persistent=False) + self.dequant_wkv_b = None + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor]): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + start_pos (int): Starting position in the sequence for caching. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
+ """ + bsz, seqlen, _ = x.size() + end_pos = start_pos + seqlen + qr = self.q_norm(self.wq_a(x)) + q = self.wq_b(qr) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + kv = self.wkv_a(x) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv = self.kv_norm(kv) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + self.kv_cache[:bsz, start_pos:end_pos] = kv + self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + if mask is not None: # MHA prefill + q = torch.cat([q_nope, q_pe], dim=-1) + kv = self.wkv_b(kv) + kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) + scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale + + # indexer + topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask) + index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), + device=x.device).scatter_(-1, topk_indices, 0) + index_mask += mask + scores += index_mask.unsqueeze(2) + + scores = scores.softmax(dim=-1, dtype=torch.float32) + x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v) + else: # MHA decode + if self.dequant_wkv_b is None and self.wkv_b.scale is not None: + self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale) + wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b + wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) + q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) + scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), + self.kv_cache[:bsz, :end_pos].float()) + + torch.einsum("bshr,btr->bsht", q_pe.float(), + self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale + + # indexer + topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask) + index_mask = torch.full((bsz, 1, end_pos), float("-inf"), + device=x.device).scatter_(-1, topk_indices, 0) + scores += index_mask.unsqueeze(2) + + scores = scores.softmax(dim=-1, dtype=torch.float32) + x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos]) + x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) + x = self.wo(x.flatten(2)) + return x + + +class MLP(nn.Module): + """ + Multi-Layer Perceptron (MLP) used as a feed-forward layer. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True): + """ + Initializes the MLP layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = ColumnParallelLinear(dim, inter_dim) + self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output) + self.w3 = ColumnParallelLinear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MLP layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after MLP computation. 
+ """ + return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)) + + +class Gate(nn.Module): + """ + Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. + + Attributes: + dim (int): Dimensionality of input features. + topk (int): Number of top experts activated for each input. + n_groups (int): Number of groups for routing. + topk_groups (int): Number of groups to route inputs to. + score_func (str): Scoring function ('softmax' or 'sigmoid'). + route_scale (float): Scaling factor for routing weights. + weight (torch.nn.Parameter): Learnable weights for the gate. + bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the Gate module. + + Args: + args (ModelArgs): Model arguments containing gating parameters. + """ + super().__init__() + self.dim = args.dim + self.topk = args.n_activated_experts + self.n_groups = args.n_expert_groups + self.topk_groups = args.n_limited_groups + self.score_func = args.score_func + self.route_scale = args.route_scale + self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) + self.bias = nn.Parameter(torch.empty(args.n_routed_experts, + dtype=torch.float32)) if self.dim == 7168 else None + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the gating mechanism. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. + """ + scores = linear(x.float(), self.weight.float()) + if self.score_func == "softmax": + scores = scores.softmax(dim=-1) + else: + scores = scores.sigmoid() + original_scores = scores + if self.bias is not None: + scores = scores + self.bias + if self.n_groups > 1: + scores = scores.view(x.size(0), self.n_groups, -1) + if self.bias is None: + group_scores = scores.amax(dim=-1) + else: + group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.topk_groups, dim=-1)[1] + mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) + indices = scores.topk(self.topk, dim=-1)[1] + weights = original_scores.gather(1, indices) + if self.score_func == "sigmoid": + weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale + return weights, indices + + +class Expert(nn.Module): + """ + Expert layer for Mixture-of-Experts (MoE) models. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int): + """ + Initializes the Expert layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = Linear(dim, inter_dim) + self.w2 = Linear(inter_dim, dim) + self.w3 = Linear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the Expert layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert computation. + """ + return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)) + + +class MoE(nn.Module): + """ + Mixture-of-Experts (MoE) module. + + Attributes: + dim (int): Dimensionality of input features. 
+ n_routed_experts (int): Total number of experts in the model. + n_local_experts (int): Number of experts handled locally in distributed systems. + n_activated_experts (int): Number of experts activated for each input. + gate (nn.Module): Gating mechanism to route inputs to experts. + experts (nn.ModuleList): List of expert modules. + shared_experts (nn.Module): Shared experts applied to all inputs. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the MoE module. + + Args: + args (ModelArgs): Model arguments containing MoE parameters. + """ + super().__init__() + self.dim = args.dim + assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" + self.n_routed_experts = args.n_routed_experts + self.n_local_experts = args.n_routed_experts // world_size + self.n_activated_experts = args.n_activated_experts + self.experts_start_idx = rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + self.gate = Gate(args) + self.experts = nn.ModuleList([ + Expert(args.dim, args.moe_inter_dim) + if self.experts_start_idx <= i < self.experts_end_idx else None + for i in range(self.n_routed_experts) + ]) + self.shared_experts = MLP( + args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MoE module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert routing and computation. + """ + shape = x.size() + x = x.view(-1, self.dim) + weights, indices = self.gate(x) + y = torch.zeros_like(x, dtype=torch.float32) + counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() + for i in range(self.experts_start_idx, self.experts_end_idx): + if counts[i] == 0: + continue + expert = self.experts[i] + idx, top = torch.where(indices == i) + y[idx] += expert(x[idx]) * weights[idx, top, None] + y += self.shared_experts(x) + if world_size > 1: + dist.all_reduce(y) + return y.type_as(x).view(shape) + + +class Block(nn.Module): + """ + Transformer block combining attention and feed-forward layers. + + Attributes: + attn (nn.Module): Attention layer (MLA). + ffn (nn.Module): Feed-forward network (MLP or MoE). + attn_norm (nn.Module): Layer normalization for attention. + ffn_norm (nn.Module): Layer normalization for feed-forward network. + """ + + def __init__(self, layer_id: int, args: ModelArgs): + """ + Initializes the Transformer block. + + Args: + layer_id (int): Layer index in the transformer. + args (ModelArgs): Model arguments containing block parameters. + """ + super().__init__() + self.attn = MLA(args) + self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) + self.attn_norm = RMSNorm(args.dim) + self.ffn_norm = RMSNorm(args.dim) + + def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, + freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position in the sequence. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor after block computation. 
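+
+        Note:
+            The block keeps a separate residual stream: `residual` carries the running
+            residual from the previous block (pass None for the first block), and the
+            method returns the tuple (x, residual) rather than a single tensor.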
+ """ + if residual is None: + x, residual = self.attn_norm(x), x + else: + x, residual = self.attn_norm(x, residual) + x = self.attn(x, start_pos, freqs_cis, mask) + x, residual = self.ffn_norm(x, residual) + x = self.ffn(x) + return x, residual + + +class Transformer(nn.Module): + """ + Transformer model with positional embeddings, multiple layers, and output projection. + + Attributes: + max_seq_len (int): Maximum sequence length for the transformer. + embed (nn.Module): Embedding layer for input tokens. + layers (torch.nn.ModuleList): List of transformer blocks. + norm (nn.Module): Layer normalization applied after all blocks. + head (nn.Module): Output projection layer mapping to vocabulary size. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the Transformer model. + + Args: + args (ModelArgs): Model arguments containing transformer parameters. + """ + global world_size, rank + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 + Linear.scale_fmt = args.scale_fmt + super().__init__() + self.max_seq_len = args.max_seq_len + self.embed = ParallelEmbedding(args.vocab_size, args.dim) + self.layers = torch.nn.ModuleList() + for layer_id in range(args.n_layers): + self.layers.append(Block(layer_id, args)) + self.norm = RMSNorm(args.dim) + # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later. + self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32) + self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int = 0): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + seqlen = tokens.size(1) + freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen] + mask = torch.full( + (seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None + h, residual = self.embed(tokens), None + for layer in self.layers: + h, residual = layer(h, residual, start_pos, freqs_cis, mask) + h, _ = self.norm(h, residual) + logits = self.head(h[:, -1].float()) + if world_size > 1: + all_logits = [torch.empty_like(logits) for _ in range(world_size)] + dist.all_gather(all_logits, logits) + logits = torch.cat(all_logits, dim=-1) + return logits + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.manual_seed(0) + args = ModelArgs() + x = torch.randint(0, args.vocab_size, (2, 128)) + model = Transformer(args) + print(model(x).size()) diff --git a/tilelang/original/examples/deepseek_v32/inference/requirements.txt b/tilelang/original/examples/deepseek_v32/inference/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..604fed552ca3f44307e1fe3a27bab5ba01c3bc9e --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/inference/requirements.txt @@ -0,0 +1,5 @@ +torch +transformers +safetensors +fast_hadamard_transform +tilelang==0.1.6 \ No newline at end of file diff --git a/tilelang/original/examples/deepseek_v32/sparse_mla_bwd.py b/tilelang/original/examples/deepseek_v32/sparse_mla_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d8035c1ba0e10d4aa7cfadae0d8017ba1333f097 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/sparse_mla_bwd.py @@ -0,0 +1,341 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + B, + S, + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + shape = [B, S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + B, + S_kv, + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + dkv_shape = [B, S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): + T.copy( + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], + ) 
+ + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) +def bwd( + B, + S, + S_kv, + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=256, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) + + H_kv = H // kv_group + q_shape = [B, S, H, D + D_tail] + k_shape = [B, S_kv, kv_group, D + D_tail] + o_shape = [B, S, H, D] + indices_shape = [B, S, kv_group, topk] + delta_shape = [B, S, H] + lse_shape = [B, S, H] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) + + max_kv_i = s_i + + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, 
BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert kv.shape[0] == B + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (B, S, kv_group, topk) + assert lse.shape == (B, S, H) + + # Get kernels + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, 
do, indices, lse, delta, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/tilelang/original/examples/deepseek_v32/sparse_mla_fwd.py b/tilelang/original/examples/deepseek_v32/sparse_mla_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c4d2f0463da163b05832f600ccc44c66ded58a --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/sparse_mla_fwd.py @@ -0,0 +1,296 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=2, + threads=256, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = 
T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + Lse_shared = T.alloc_shared([H_per_block], accum_dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + 
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, O_shared) + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse_shared) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, 
topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + SKV=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/tilelang/original/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/tilelang/original/examples/deepseek_v32/sparse_mla_fwd_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..54e1a72090152dbb0d1e325784774f4f42efb5a9 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -0,0 +1,438 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from tilelang.engine.callback import register_cuda_postproc_callback +import argparse + + +@tilelang.jit( + out_idx=[-2, -1], + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + ], +) +def sparse_mla_fwd( + batch, + seq_len, + seq_len_kv, + heads, + dim, + tail_dim, + topk, + kv_stride, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=0, + threads=384, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = 
tilelang.cdiv(topk, block_I) + assert NI % 2 == 0, "NI should be a multiple of 2" + D = dim + D_tail = tail_dim + KV_stride = kv_stride + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + is_kv_valid = T.alloc_shared([BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + indices_local = T.alloc_local([1], indices_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + b_i, g_i = by, bz + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + q_i = q_start_index_s[0] + s_i + max_kv_i = (q_i + 1 - KV_stride) // KV_stride + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, 
wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? 
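+                    # Online-softmax update: sumexp_i holds this block's partial row sums,
+                    # so the running denominator sumexp is first rescaled by alpha (the
+                    # correction for the new row max) and only then are the partial sums added.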
+ for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(H_per_block): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in 
T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = 512 + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + if q_start_index_s != 0: + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) + CP0 = q_start_index_s == 0 + + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + if print_kernel: + print(kernel.get_kernel_source()) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + if return_kernel: + return kernel + if q_start_index_s == 0 and kv_stride > 1: + out[:, : kv_stride - 1, :, :] = 0 + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + if q_start_index_s is None: + q_start_index_s = sk * kv_stride - sq + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : kv_stride - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q_start_s_index_t = torch.tensor([q_start_s_index], 
dtype=torch.int32, device="cuda") + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + def fn(): + out, lse = kernel(q, kv, indices, q_start_s_index_t) + if q_start_s_index == 0 and KV_stride > 1: + out[:, : KV_stride - 1, :, :] = 0 + return out, lse + + tl_out, tl_lse = fn() + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) + # print(f"tl_out: {tl_out}") + # print(f"ref_out: {ref_out}") + + torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=10, + warmup=10, + ) + print(f"Average time: {ms:.3f} ms") + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test_correctness", action="store_true") + args = parser.parse_args() + if args.test_correctness: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + else: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/tilelang/original/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/tilelang/original/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7e879ba6e4fc1155660ad6be10da917c7a5ad5 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -0,0 +1,41 @@ +# ruff: noqa +import tilelang +import tilelang.testing + +import topk_selector +import fp8_lighting_indexer +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import sparse_mla_bwd + + +def test_example_topk_selector(): + topk_selector.test_topk_selector() + + +def test_example_fp8_lighting_indexer(): + fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_fwd(): + # small shapes for testing + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_fwd_pipelined(): + # small shapes for testing + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_bwd(): + sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/deepseek_v32/topk_selector.py b/tilelang/original/examples/deepseek_v32/topk_selector.py new file mode 100644 index 
0000000000000000000000000000000000000000..244f74c69615430ca32ea847529639a1ccf65cca --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/topk_selector.py @@ -0,0 +1,244 @@ +import torch +import tilelang +import tilelang.language as T + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, +} + + +def convert_to_uint16(x): + hval = T.Cast(T.float16, x) + bits_uint = T.reinterpret(T.uint16, hval) + bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) + return bits_uint >> 8 + + +def convert_to_uint32(x): + bits_uint = T.reinterpret(T.uint32, x) + bits_uint = T.if_then_else( + x < 0, + ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), + bits_uint | T.Cast(T.uint32, (0x80000000)), + ) + return bits_uint + + +@tilelang.jit(pass_configs=pass_configs) +def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + RADIX = 1 << 8 + BLOCK_SIZE = 1024 + SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K + + @T.prim_func + def tl_topk_kernel( + input: T.Tensor[(batch, seq_len), in_dtype], + index: T.Tensor[(batch, topk), out_dtype], + starts: T.Tensor[(batch), out_dtype], + ends: T.Tensor[(batch), out_dtype], + ): + with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + tx = T.get_thread_binding() + + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) + + l_new_topk = topk + l_start_idx = starts[bx] + l_end_idx = ends[bx] + + # stage 1: use 8bit to do quick topk + T.fill(s_histogram, 0) + T.fill(s_num_input[0], 0) + + T.sync_threads() + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) + T.sync_threads() + + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + # collect all elements with exponent ≥ threshold + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + T.sync_threads() + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + bin_id = convert_to_uint16(input[bx, input_idx]) + l_bin_id32 = T.Cast(T.int32, bin_id) + if l_bin_id32 > l_threshold_bin_id: + # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + index[bx, pos] = input_idx + + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + # pos = s_num_input[0] + 
pos = T.atomic_add(s_num_input[0], 1, return_prev=True) + s_input_idx[0, pos] = input_idx + + # stage 2: tail pass + for round in T.serial(4): + if l_new_topk <= 0: + T.loop_break() + + r_idx = round % 2 + l_start_pos = topk - l_new_topk + + T.sync_threads() + T.fill(s_histogram, 0) + if tx == 0: + s_num_input[r_idx ^ 1] = 0 + T.sync_threads() + + l_num_input = s_num_input[r_idx] + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) + T.atomic_add(s_histogram[l_bin_id32], 1) + T.sync_threads() + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + T.sync_threads() + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) + if l_bin_id32 > l_threshold_bin_id: + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + if round == 3: + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + if l_out_pos < topk: + index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + else: + pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + + return tl_topk_kernel + + +def tl_topk(input, starts, ends, topk): + batch, seq_len = input.shape + indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device) + kernel = tl_topk_impl(topk) + kernel(input, indexes, starts, ends) + return indexes + + +def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): + batch = 64 + seq_len = 32 * 1024 + topk = 2048 + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + indexes = tl_topk(input, starts, ends, topk) + print(indexes) + + indexes_ref = torch.topk(input, topk, dim=-1)[1] + print(indexes_ref) + + # indexes_ref = fast_topk(input, topk) + # print(indexes_ref) + + # Calculate intersection of out_ref and out_trt + for i in range(batch): + ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() + trt_np = indexes[i].cpu().to(torch.int32).numpy() + + set_ref = set(ref_np) + set_trt = set(trt_np) + intersection = set_ref & set_trt + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + # Performance test with CUDA events + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Warmup + for _ in range(5): + _ = tl_topk(input, starts, ends, topk) + torch.cuda.synchronize() + + n_iters = 20 + 
start_event.record() + for _ in range(n_iters): + _ = tl_topk(input, starts, ends, topk) + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms") + + # Torch topk time + start_event.record() + for _ in range(n_iters): + _ = torch.topk(input, topk, dim=-1)[1] + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") + + +if __name__ == "__main__": + test_topk_selector() diff --git a/tilelang/original/examples/deepseek_v32/utils.py b/tilelang/original/examples/deepseek_v32/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7252e171108aa13396f6d3e91d84d04de1d3c17 --- /dev/null +++ b/tilelang/original/examples/deepseek_v32/utils.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +import contextlib +import functools +import logging +import os +import sys +from enum import Enum +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +from packaging import version + + +def _is_equal(a, b): + if isinstance(a, torch.Tensor): + return a is b + # Whitelist of types that are safe to compare by value for caching. + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): + return a == b + # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. + return False + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: Optional[Tuple] = None + last_kwargs: Optional[Dict] = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if last_args is not None and last_kwargs is not None: + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + # For Tensors, check for object identity. For other types, check for equality. + # Python caches small integers, so `is` works for them but not for large integers like 4096. 
+ if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): + seq_idx = cu_seqlens.new_zeros(seq_len + 1) + seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx)) + seq_idx.cumsum_(0) + return seq_idx[:-1] + + +@tensor_cache +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i + return seq_idx_for_q + + +@tensor_cache +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: + cu_seqlen_ks_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + return cu_seqlen_ks_for_each_q.int() + + +@tensor_cache +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: + cu_seqlen_ke_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] + cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) + return cu_seqlen_ke_for_each_q.int() + + +@tensor_cache +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ + seq_len: seq len per cp rank + balanced cp slice assignment: 0 1 2 3 3 2 1 0 + """ + n_seq = len(cu_seqlens_q) - 1 + assert n_seq > 0 + assert cu_seqlens_q.shape == (n_seq + 1,) + seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size) + qs = cu_seqlens_q.gather(0, seq_idx) + pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs + if offs_q is not None: + assert offs_q.shape == (n_seq,), offs_q.shape + qoff = offs_q.gather(0, seq_idx) + pos += qoff + if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q: + ks = qs + else: + assert cu_seqlens_k.shape == (n_seq + 1,) + ks = cu_seqlens_k.gather(0, seq_idx) + ke = ks + (pos + 1) // kv_stride + + if cp_size == 1: + 
pass
+    elif balanced_cp:
+        assert cp_size % 2 == 0, cp_size
+
+        def f(x: torch.Tensor):
+            chunks = x.chunk(cp_size * 2)
+            return torch.cat(
+                [
+                    chunks[cp_rank],
+                    chunks[cp_size - cp_rank - 1],
+                ]
+            )
+
+        ks = f(ks)
+        ke = f(ke)
+    else:
+        ks = ks.chunk(cp_size)[cp_rank]
+        ke = ke.chunk(cp_size)[cp_rank]
+
+    return ks, ke
+
+
+def ceil_to_ue8m0(x: torch.Tensor):
+    assert x.view(-1).amax().item() > 0
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
+def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+    excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
+    x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled, sf.squeeze()
+
+
+def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512):
+    total_seqlen = per_cp_seqlen * cp_size
+
+    cu_seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda()
+    last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
+    cu_seqlens = cu_seqlens[:last_seq_id]
+
+    if cu_seqlens.sum() < total_seqlen:
+        cu_seqlens = torch.cat([cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()])
+
+    cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0)
+    cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0)
+    cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]])
+    cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]])
+    cu_seqlens_qe = cu_seqlens_cumsum.clone()
+    cu_seqlens_ke = cu_seqlens_k_cumsum.clone()
+
+    cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q(
+        cu_seqlens_qs=cu_seqlens_qs,
+        cu_seqlens_qe=cu_seqlens_qe,
+        cu_seqlens_ks=cu_seqlens_ks,
+        seq_len=total_seqlen,
+    )
+    cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q(
+        cu_seqlens_qs=cu_seqlens_qs,
+        cu_seqlens_qe=cu_seqlens_qe,
+        cu_seqlens_ks=cu_seqlens_ks,
+        cu_seqlens_ke=cu_seqlens_ke,
+        q_start_idxs=torch.zeros_like(cu_seqlens_qs),
+        seq_len=total_seqlen,
+        kv_stride=kv_stride,
+    )
+
+    assert per_cp_seqlen % 2 == 0
+    per_chunk_seqlen = per_cp_seqlen // 2
+    slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen)
+    slice_long = slice(
+        total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
+        total_seqlen - cp_rank * per_chunk_seqlen,
+    )
+    ks = torch.cat(
+        [
+            cu_seqlens_ks_for_each_q[slice_short],
+            cu_seqlens_ks_for_each_q[slice_long],
+        ]
+    )
+    ke = torch.cat(
+        [
+            cu_seqlens_ke_for_each_q[slice_short],
+            cu_seqlens_ke_for_each_q[slice_long],
+        ]
+    )
+    assert len(ks) == len(ke) == per_cp_seqlen
+    return ks, ke
+
+
+def calculate_tensor_similarity(x, y, name="tensor"):
+    """
+    Calculate similarity between two tensors using a normalized dot product metric.
+
+    Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
+    element-wise differences, this function computes a global similarity score:
+        sim = 2 * <x, y> / (||x||^2 + ||y||^2)
+
+    This metric is scale-invariant and measures the cosine-like similarity normalized
+    by the magnitude of both tensors. It returns 1 for identical tensors and values
+    closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
+    with varying magnitudes where relative errors matter more than absolute differences.
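+
+    A small worked example: for x = [1, 1] and y = [1, 0] the score is
+    sim = 2 * 1 / (2 + 1) = 2 / 3, while identical tensors give sim = 1.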
+ + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 + + +if __name__ == "__main__": + seq_len = 32768 + cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") + last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] + cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + + from tilelang.profiler import do_bench + + fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) # noqa: E731 + ms = do_bench(fn, warmup=25, rep=100) diff --git a/tilelang/original/examples/demo.py b/tilelang/original/examples/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3a35d83132f7e8118c644f7fde6461021f2ad8fe --- /dev/null +++ b/tilelang/original/examples/demo.py @@ -0,0 +1,47 @@ +import tilelang +import tilelang.language as T + +@tilelang.jit(target="hip") +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return 
matmul_relu_kernel
+
+
+M, N, K = 1024, 1024, 1024
+block_M, block_N, block_K = 128, 128, 32
+
+matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
+
+# Test on a DCU using HIP-compatible data from PyTorch (or a similar library)
+import torch
+a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+c = torch.empty(M, N, device="cuda", dtype=torch.float16)
+
+matmul_relu_kernel(a, b, c)
+ref_c = torch.relu(a @ b)
+torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+print("Kernel output matches reference.")
diff --git a/tilelang/original/examples/dequantize_gemm/README.md b/tilelang/original/examples/dequantize_gemm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c6116775e57b9d02df3c5f49763fb9d9df509fc
--- /dev/null
+++ b/tilelang/original/examples/dequantize_gemm/README.md
@@ -0,0 +1,39 @@
+
+### Dequantization GEMM
+
+An example of implementing a dequantization GEMM:
+
+```python
+@T.prim_func
+def dequant_matmul(
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, storage_dtype),
+        Ct: T.Tensor((N, M), out_dtype),
+):
+    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
+        B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
+        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
+        Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
+
+        T.clear(Ct_local)
+        for k in T.Pipelined(
+                T.ceildiv(K, block_K),
+                num_stages=num_stages
+        ):
+            T.copy(A[by * block_M, k * block_K], A_shared)
+            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
+            T.copy(B_shared, B_local)
+            for i, j in T.Parallel(block_N, block_K):
+                B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)(
+                    num_bits,
+                    B_local[i, j // 2],
+                    j % 2,
+                    dtype=in_dtype,
+                )
+            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
+        T.copy(Ct_local, Ct[bx * block_N, by * block_M])
+```
+
+**Notes:** Dequantize GEMM with magic layout transformations to achieve optimal performance can be found in the [BitBLAS](https://github.com/microsoft/BitBLAS) project; example kernels are available at `testing/python/kernel/test_tilelang_dequantize_gemm.py`. A detailed explanation and more examples are coming soon.
diff --git a/tilelang/original/examples/dequantize_gemm/dequantize_utils.py b/tilelang/original/examples/dequantize_gemm/dequantize_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90a6265ffa4bf22c3d583e74e53066161c80a37a
--- /dev/null
+++ b/tilelang/original/examples/dequantize_gemm/dequantize_utils.py
@@ -0,0 +1,148 @@
+import torch
+
+
+def torch_convert_bit_twiddling(tensor):
+    """
+    This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
+
+    Parameters:
+        tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K).
+
+    Returns:
+        torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns.
+
+    Raises:
+        AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
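+
+    Example (illustrative; assumes a PyTorch build where torch.uint16 views are available):
+        >>> packed = torch.randint(0, 256, (4, 8), dtype=torch.uint8)
+        >>> torch_convert_bit_twiddling(packed).shape
+        torch.Size([4, 16])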
+ """ + assert tensor.dim() == 2 and tensor.dtype == torch.uint8 + N, K = tensor.shape + assert K % 2 == 0, "Number of columns must be even" + + # Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA + val0 = tensor[:, 0::2].to(torch.int32) + val1 = tensor[:, 1::2].to(torch.int32) + val_concat = (val0 << 8) | val1 # (N, K//2), uint32 + + # Expand to match output shape where each pair generates 4 values + val_concat_expanded = val_concat.repeat_interleave(4, dim=1) # (N, K//2*4) + + # Positional encoding for bit-twiddling logic + pos = torch.arange(K * 2, device=tensor.device) % 4 # (K*2,) + + # Bit masks for decoding (as uint32 for CUDA compatibility) + mask = 0b1000000111000000 + mask1 = 0b1000000000000000 + mask2 = 0b0000000110000000 + mask3 = 0b0000000001000000 + + # Calculate results for all 4 positions in parallel + res0 = val_concat_expanded & mask + res1 = (val_concat_expanded << 3) & mask + res2 = (val_concat_expanded << 6) & mask + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) + + # Select the correct result based on position + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) + + # Convert to uint16 for .view(torch.bfloat16) + bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) + bf16_bf16 = bf16_uint16.view(torch.bfloat16) + + # Avoid integer overflow by using a float32 multiplier for the exponent scaling + bf16_new = bf16_bf16 * (2.0**126) + + return bf16_new + + +def torch_convert(tensor, scale_size=None, Scale=None): + """ + Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding. + + Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input. + + Parameters: + tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values. + scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale. + Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size]. + + Returns: + torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values. 
+ """ + + def _convert(val, pos, scale=None): + assert val.dtype == torch.uint8 + # val = val.view(torch.int8) + mask = (1 << 4) - 1 + f4 = ((val >> (pos * 4)) & mask).to(torch.int16) + s = f4 >> 3 + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 126 + if scale is not None: + e_f16 = min(e_f16 + scale, (1 << 8) - 1) + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF + lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) + return lower_16_bits.view(torch.bfloat16) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + if scale_size is not None: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size]) + else: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +def print_bit(name, val): + """ + Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor. + + Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`. + + Parameters: + name (str): Label printed before the binary representation. + val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. + """ + val_cpu = val.cpu().item() + binary_repr = f"{val_cpu:032b}" + print(name, binary_repr) + + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): + x_mask = torch.isfinite(x) + y_mask = torch.isfinite(y) + if not torch.all(x_mask == y_mask): + print_red_warning(f"{name} Error: isfinite mask mismatch") + if raise_assert: + raise AssertionError + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") + if raise_assert: + raise AssertionError + x = x.masked_fill(~x_mask, 0) + y = y.masked_fill(~y_mask, 0) + sim = calc_sim(x, y, name) + diff = (1.0 - sim).item() + print(f"{diff=}") + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff=}") + if raise_assert: + raise AssertionError diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9c945b36083073e020b424725da949614d4df0 --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -0,0 +1,443 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def get_configs(): + """ + Return a list of tuning configuration dictionaries for the autotuned matmul kernel. 
+ + Each dictionary is a single combination (Cartesian product) of the following parameters: + - block_M: tile size for M dimension (one of 64, 128, 256) + - block_N: tile size for N dimension (one of 64, 128, 256) + - block_K: tile size for K dimension + - num_stages: pipeline stages for K-loop (0 or 2) + - threads: number of threads to launch (128, 256, or 512) + - split: K-splitting factor (1 or 2) + + Returns: + list[dict]: List of configuration dicts usable by the autotuner, where each dict maps + the parameter name to its chosen value. + """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[128], + num_stages=[0, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. + - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. 
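+
+    Example (illustrative; mirrors the fixed configuration used by `main` below and
+    assumes a Hopper-class GPU):
+        kernel = matmul(256, 256, 256, T.bfloat16, T.bfloat16, T.float32,
+                        num_bits=4, fast_dequant=True, block_M=256, block_N=128,
+                        block_K=128, num_stages=2, threads=256, split=1)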
+ """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + # It requires that the 2 consecutive uint8 elements (16bits) contains 4 fp4 elements in a bit-twiddling way. + # The bit-twiddling way is shown here: The pair (x,y) shows that the bit in this position is the y-th bit of the x-th fp4 element. + # (0,0)(3,0)(3,3)(1,0)(3,1)(3,2)(2,0)(0,1)(0,2)(0,3)(1,1)(1,2)(1,3)(2,1)(2,2)(2,3) + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin. + + This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which: + - Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads). + - Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers. + - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel. + + Notes and preconditions: + - Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`. + - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel. + - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly. + - The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): + # import fast_dequantize plugin + """ + Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer. 
+ + This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters. + + Parameters: + B_shared: Shared-memory buffer containing packed quantized values (packed FP4 layout). + B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine). + + Side effects: + - Imports the external dequantization plugin via `import_source` and invokes `func_name`. + - Writes dequantized BF16 results into `B_dequantize_shared`. + + Notes: + - This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`). + - No value is returned; results are produced by mutation of `B_dequantize_shared`. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + for v in T.vectorized(0, local_compress_size): + index = i * threads * local_compress_size + tx * local_compress_size + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16. + + The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like + `B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It: + - Unpacks 4-bit FP values from the packed uint8 representation in B_shared. + - Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`. + - Writes the dequantized bfloat16 block into B_dequantize_shared. + + Constraints: + - Supports only in_dtype="fp4" and out_dtype=T.bfloat16. + - The helper assumes nbit == 4 and produces bfloat16 values. + - The macro uses a fixed test-scale of 0 (no per-element scaling) as written. + + Returns: + A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16. 
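+
+        Worked example (scale == 0): the nibble 0b0111 has sign 0, exponent bits 0b11
+        and mantissa bit 1, so `_tir_u8_to_f4_to_bf16` assembles exponent 3 + 126 = 129
+        and the bfloat16 bit pattern 0x40C0, i.e. 6.0.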
+ """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be T.bfloat16. + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8. + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. + """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we use the max function to limit the exponential part to 8 bits + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared): + """ + Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer. + + This helper: + - Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared. + - Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`. + - Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope. + + Parameters: + B_shared: shared-memory buffer containing packed FP4 data (uint8-packed). + B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values. + + Side effects: + Writes dequantized BF16 values into B_dequantize_shared. No return value. 
+ """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_shared[i, j // num_elems_per_byte], + j % num_elems_per_byte, + 0, # No scale for test + dtype=out_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. + - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.annotate_layout( + { + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + T.clear(C_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared) + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB): + """ + Compute reference BF16 matrix multiply using bit-twiddled FP4 quantized B. + + Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating, + performs C = A @ B^T in full precision, and returns the result converted to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute). + qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout. 
+ + Returns: + torch.Tensor: Result matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple(A, qB): + """ + Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB. + + Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning. + + Parameters: + A (torch.Tensor): Left input matrix with shape (M, K). + qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A. + + Returns: + torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, fast_dequant=True, tune=False): + """ + Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference. + + This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs. + + Parameters: + m (int): Number of rows of A and output C (default 256). + n (int): Number of columns of B and output C (default 256). + k (int): Inner dimension (columns of A, rows of B) (default 256). + fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True). + tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False). + + Side effects: + - Prints latency and TFLOPs to stdout. + - Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01). 
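+
+    Example (illustrative, mirroring the `__main__` block): `main(256, 256, 256, fast_dequant=True)`
+    builds the fixed-tiling kernel, validates it against `ref_program_twiddling`, and prints the
+    measured latency and TFLOPs.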
+ """ + total_flops = 2 * m * n * k + if tune: + kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + if fast_dequant: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + main(256, 256, 256, True) + main(256, 256, 256, False) diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0375a1db8d89c732f3053a8a3280ccfb7df940 --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -0,0 +1,547 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). + + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. 
+ """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we may use the min function to limit the exponential part to 8 bits + # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes (64, 128, 256) + - num_stages: pipeline stages (0, 2) + - threads: thread counts (128, 256, 512) + - split: K-splitting factor (1, 2) + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. + """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128, 256], + num_stages=[0, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). 
+ block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. + """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + Bias_shape = (M, N) + Scale_shape = (N, K // scale_size) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = (block_M, block_N) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. + + The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. 
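+
+        Sizing sketch: with out_dtype == bfloat16 (16 bits) each per-thread transaction
+        covers local_size = 128 / 16 = 8 output values, i.e. local_compress_size = 8 / 2 = 4
+        packed bytes per iteration.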
+ """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. + - Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + bx = T.get_block_binding(0) + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. 
+ for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. + + Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. + + Notes: + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. + - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): + """ + Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents. + + Per-element behavior: + - Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte). + - Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16. + - Writes the dequantized BF16 block into B_dequantize_shared. + + Parameters: + - B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout). + - B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results. + - Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element. + - k: current block index along the K dimension (used to select the appropriate slice of Scale). + + Side effects: + - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. + """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + bx = T.get_block_binding(0) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale[ + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. 
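+
+            Scale indexing sketch: with scale_size == 32 and block_K == 128 each tile row
+            carries block_K // scale_size == 4 scale entries, and element (i, j) uses
+            Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size].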
+ + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + if with_bias: + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) + + if threads == 512: + T.disable_warp_group_reg_alloc() + + if with_bias: + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) + T.copy(Bias_shared, C_local) + else: + T.clear(C_local) + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale, k) + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB, Scale, Bias=None): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. 
+ """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_twiddling_with_bias(A, qB, Scale, Bias): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + Bias (torch.Tensor): Bias tensor with shape (M, N). + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple(A, qB, Scale, Bias=None): + """ + Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16. + + Parameters: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + + Returns: + - 2D bfloat16 tensor C containing the matrix product A · B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple_with_bias(A, qB, Scale, Bias): + """ + Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16. + + Parameters: + + Returns: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul). + + + Returns: + - 2D bfloat16 tensor C containing the matrix product A · B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). 
+ """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + """ + Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS. + + Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS. + + Parameters: + m (int): Number of rows of A / output rows. Default 256. + n (int): Number of columns of B / output columns. Default 256. + k (int): Reduction dimension. Default 256. + scale_size (int): Size of the per-block scale vector used for dequantization. Default 32. + fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True. + tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False. + + Returns: + None + """ + total_flops = 2 * m * n * k + + if tune: + kernel = matmul( + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + + if fast_dequant: + if with_bias: + profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + if with_bias: + profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + M, N, K = 256, 256, 256 + scale_size = 32 + main(M, N, K, scale_size, fast_dequant=True, with_bias=True) + main(M, N, K, scale_size, fast_dequant=False, with_bias=True) + main(M, N, K, scale_size, fast_dequant=True, with_bias=False) + main(M, N, K, scale_size, fast_dequant=False, with_bias=False) diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90418bc75e73a4345525e7ecc339b254a0ab0c --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -0,0 +1,563 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from 
dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). + + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we may use the min function to limit the exponential part to 8 bits + # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes (64, 128, 256) + - num_stages: pipeline stages (0, 2) + - threads: thread counts (128, 256, 512) + - split: K-splitting factor (1, 2) + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. 
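+
+    Example entry (one element of the returned list):
+        {'block_M': 64, 'block_N': 128, 'block_K': 128,
+         'num_stages': 2, 'threads': 256, 'split': 1}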
+ """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128, 256], + num_stages=[0, 1, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. 
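+      - Example: with num_bits=4 and K=256, B is stored as uint8 with QK = K // 2 = 128
+        packed columns per row, and with scale_size=32 Scale has shape (N, K // 32) = (N, 8).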
+ """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + Bias_shape = (M, N) + Scale_shape = (N, K // scale_size) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = (block_M, block_N) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. + + The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. 
+ - Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + bx = T.get_block_binding(0) # noqa: F841 + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. + + Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. + + Notes: + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. + - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): + """ + Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents. + + Per-element behavior: + - Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte). + - Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16. 
+ - Writes the dequantized BF16 block into B_dequantize_shared. + + Parameters: + - B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout). + - B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results. + - Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element. + - k: current block index along the K dimension (used to select the appropriate slice of Scale). + + Side effects: + - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. + """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + bx = T.get_block_binding(0) # noqa: F841 + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale_shared[ + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. 
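+        - Example launch: with M=N=K=256 and the default tiling (block_M=256, block_N=128,
+          block_K=128), the grid is ceildiv(N, 128) x ceildiv(M, 256) = 2 x 1 blocks, each
+          running K // block_K = 2 pipelined K iterations.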
+ """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + # To use 1D TMA, the last dim of Scale_shared must have stride=1 + # May use much more shared memory than necessary + Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + if with_bias: + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) + + if threads == 512: + T.disable_warp_group_reg_alloc() + + if with_bias: + # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], + # Bias_shared) + # T.copy(Bias_shared, C_local) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local) + else: + T.clear(C_local) + + # Use 1D TMA to load Scale + T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared) + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB, Scale, Bias=None): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_twiddling_with_bias(A, qB, Scale, Bias): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. 
+
+    Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, adds Bias, and casts the result to bfloat16.
+
+    Parameters:
+        A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
+        qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
+        Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
+        Bias (torch.Tensor): Bias tensor with shape (M, N).
+
+    Returns:
+        torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
+    """
+    dtypeC = T.bfloat16
+    B = torch_convert_bit_twiddling(qB)
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
+    C = C.to(torch.__getattribute__(dtypeC))
+    return C
+
+
+def ref_program_simple(A, qB, Scale, Bias=None):
+    """
+    Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
+
+    Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
+
+    Parameters:
+    - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
+    - qB: Quantized representation of B accepted by `torch_convert`.
+    - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
+
+    Returns:
+    - 2D bfloat16 tensor C containing the matrix product A · B^T.
+
+    No in-place modification is performed on inputs (a local floating copy of B is scaled).
+    """
+    dtypeC = T.bfloat16
+    B = torch_convert(qB)
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
+    C = C.to(torch.__getattribute__(dtypeC))
+    return C
+
+
+def ref_program_simple_with_bias(A, qB, Scale, Bias):
+    """
+    Compute a BF16 matrix product A · B^T + Bias from a quantized B with simple (non-twiddling) dequantization.
+
+    Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T + Bias and returns the result converted to bfloat16.
+
+    Parameters:
+    - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
+    - qB: Quantized representation of B accepted by `torch_convert`.
+    - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
+    - Bias: 2D bias tensor with shape (M, N), added to the matmul result before the bfloat16 cast.
+
+    Returns:
+    - 2D bfloat16 tensor C containing the matrix product A · B^T + Bias.
+
+    No in-place modification is performed on inputs (a local floating copy of B is scaled).
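+
+    Example (shapes for the defaults used by ``main`` in this file, M=N=K=256, scale_size=32):
+        A: (256, 256) bfloat16, qB: (256, 128) uint8, Scale: (256, 8) uint8, Bias: (256, 256) bfloat16
+        C = ref_program_simple_with_bias(A, qB, Scale, Bias)  # C: (256, 256) bfloat16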
+ """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + """ + Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS. + + Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS. + + Parameters: + m (int): Number of rows of A / output rows. Default 256. + n (int): Number of columns of B / output columns. Default 256. + k (int): Reduction dimension. Default 256. + scale_size (int): Size of the per-block scale vector used for dequantization. Default 32. + fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True. + tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False. + + Returns: + None + """ + total_flops = 2 * m * n * k + + if tune: + kernel = matmul( + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + + if fast_dequant: + if with_bias: + profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + if with_bias: + profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + M, N, K = 256, 256, 256 + scale_size = 32 + main(M, N, K, scale_size, fast_dequant=True, with_bias=True) + main(M, N, K, scale_size, fast_dequant=False, with_bias=True) + main(M, N, K, scale_size, fast_dequant=True, with_bias=False) + main(M, N, K, scale_size, fast_dequant=False, with_bias=False) diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py new file mode 100644 index 0000000000000000000000000000000000000000..37826874bc3dfe2ee980f3b8274d56cca8bed3c6 --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -0,0 +1,434 @@ +import torch +import torch.backends +import tilelang.testing +from tilelang import tvm as tvm +from tvm import DataType +import 
tilelang.language as T + +tilelang.testing.set_random_seed(0) + + +@tilelang.jit(out_idx=[2]) +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + num_bits=4, +): + from tilelang.quantize import _tir_packed_to_unsigned_convert + + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + tx = T.get_thread_binding() + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + tx * local_size_compressed + v + vi = index // (block_K // num_elems_per_byte) + vj = index % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, local_size): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // block_K + vj = index % block_K + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + kernel = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + out = profiler.run_once() + assert out is not None + + def ref_program(A, qB): + import torch + + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program) + + +@tvm.testing.requires_package("bitblas") +def 
tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_b, +): + from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout + from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitterWithLadderTransform, + ) + + from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + + warp_rows = 4 + warp_cols = 4 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + reduce_k = 1 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = 32 if in_dtype == T.float16 else 64 + chunk = block_K // reduce_k + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + transform_kind_b=transform_b, + num_elems_per_byte=num_elems_per_byte, + ) + + vec_load_qb = 16 + if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) + B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + reduced_accum_res = T.alloc_local(0, accum_dtype) + 
thread_binding = T.get_thread_binding(0) + rk = T.get_thread_binding(1) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, (block_K // reduce_k)): + vk = rk * (block_K // reduce_k) + k + A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] + + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + t = thread_binding + idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v + vkk = idx % (micro_size_k // num_elems_per_byte) + vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + rk=rk, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + rk=rk, + ) + + for j in T.serial(warp_cols): + local_size_b = mma_emitter.local_size_b + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) + + mma_emitter.mma(A_local, B_dequantize_local, C_local) + + if reduce_k > 1: + for n in T.serial(warp_rows * warp_cols * local_size): + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_local[n], + True, + reduced_accum_res[0], + rk, + dtype="handle", + ) + ) + if rk == 0: + C_local[n] = reduced_accum_res[0] + + if rk == 0: + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + for i, j in T.Parallel(block_M, (block_N // reduce_k)): + vj = rk * (block_N // reduce_k) + j + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] + + return main + + +def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_b, +): + import bitblas + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + kernel = tilelang.compile(matmul, out_idx=[2]) + src_code = kernel.get_kernel_source() + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + # src_code is the generated cuda source + assert src_code is not None + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=transform_b, + 
transpose_matrix=True, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype=in_dtype, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + QLB = ladder_permutate(qB.cpu()).cuda() + QLB = lop3_permutate(QLB.cpu()).cuda() + + kernel(A, QLB, C) + + latency = profiler.do_bench(warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("Ref C: ", ref_c) + print("C: ", C) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_package("bitblas") +def test_run_dequantize_gemm(): + run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128) + + +@tilelang.testing.requires_package("bitblas") +def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3) + + +def main(): + test_run_dequantize_gemm() + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..79345771d6bd2e565547016c39e6ee2a4611201d --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -0,0 +1,284 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from tvm import tir +import itertools +import torch +import argparse + + +def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.float16 + assert val.dtype == T.uint8 + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 + # s1e2m1 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + e_f16 = e_f4 + tir.const(14, T.uint16) + m_f4 = f4 & tir.const(1, T.uint16) + m_f16 = m_f4 + val_f16 = tir.reinterpret( + T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16) + ) + # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16) + return val_f16 + + +def torch_convert(tensor): + def print_bit(name, val): + val_cpu = val.cpu().item() + binary_repr = f"{val_cpu:032b}" + print(name, binary_repr) + + def _convert(val, pos): + assert val.dtype == torch.uint8 + val = val.view(torch.int8) + mask = (1 << 4) - 1 + f4 = ((val >> (pos * 4)) & mask).to(torch.int16) + s = f4 >> 3 + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 14 + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 
5)) << 10) | (m_f16 << 9)) & 0xFFFF + lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) + return lower_16_bits.view(torch.float16) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +@tilelang.jit(out_idx=[1]) +def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +def test_fp4_fp16_convert_close(): + N, K = 256, 256 + block_N, block_K = 64, 64 + kernel = test_convert( + N, + K, + block_N, + block_K, + T.float16, + ) + + B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) + tl_out = kernel(B) + ref_out = torch_convert(B) + assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out) + print("Pass") + + +def get_configs(): + block_M = [64, 128] + block_N = [64, 128] + block_K = [128, 256] + num_stages = [1, 2] + threads = [128, 256] + splits = [1] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) + + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] + return configs + + +def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + @tilelang.jit(out_idx=[2]) + def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + KK = K // split + + @T.prim_func + def main_split( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) 
+ Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): + T.copy(A[by * block_M, KK * bz + k * block_K], A_shared) + T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): + acc = T.alloc_fragment((block_N, block_M), out_dtype) + T.clear(acc) + for k in range(split): + for i, j in T.Parallel(block_N, block_M): + acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j] + T.copy(acc, Ct[bx * block_N, by * block_M]) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct_shared) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + + if split == 1: + return main + else: + return main_split + + if tune: + + @autotune(configs=get_configs(), warmup=10, rep=10) + @tilelang.jit(out_idx=[2]) + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func + + return kernel() + else: + + def kernel(block_M, block_N, block_K, num_stages, threads, split=1): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + + return kernel + + +def ref_program(A, qB): + dtypeC = T.float16 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + 
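+# Note: both kernel variants above write a transposed (N, M) output ``Ct``,
+# which is why ref_program returns C.transpose(0, 1) when used as a reference.
+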
+def main(m=256, n=256, k=256, tune=False): + total_flops = 2 * m * n * k + + if not tune: + kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune) + best_latency = best_result.latency + best_config = best_result.config + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_w4a8.py new file mode 100644 index 0000000000000000000000000000000000000000..61baa668e6eb0853c4a6c2d93a5e05dc3f254e46 --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -0,0 +1,198 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from tvm import tir +import itertools +import torch +import argparse + + +def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.int8 + assert val.dtype == T.uint8 + + mask = tir.const((1 << nbit) - 1, T.uint8) + + i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask + + i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8)) + i8 = i8_shifted >> tir.const(4, T.int8) + return i8 + + +def get_configs(): + iter_params = dict( + block_M=[64, 128], + block_N=[64, 128], + block_K=[128, 256], + num_stages=[1, 2], + threads=[128, 256, 512], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tilelang.jit(out_idx=[1]) +def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in 
T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +def torch_convert(tensor): + def _convert(val, pos): + assert val.dtype == torch.uint8 + val = val.view(torch.int8) + mask = (1 << 4) - 1 + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 + + return i4.view(torch.int8) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +def ref_program(A, qB): + dtypeC = T.int32 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + +def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + @tilelang.jit(out_idx=[2]) + def kernel_func(block_M, block_N, block_K, num_stages, threads): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_local_shape = (block_N, block_K) + + assert K % (block_K) == 0 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct_shared) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + + return main + + if tune: + + @autotune(configs=get_configs(), warmup=10, rep=10) + @tilelang.jit(out_idx=[2]) + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func + + return kernel() + + else: + + def kernel(block_M, block_N, block_K, num_stages, threads): + return kernel_func(block_M, block_N, block_K, num_stages, threads) + + return kernel + + +def main(m=128, n=256, k=256, tune=False): + total_flops = 2 * m * n * k + if not tune: + kernel = 
matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) + print("All checks pass.") + + latency = profiler.do_bench(warmup=50) + print(f"Tilelang: {latency} ms") + + else: + best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune) + best_latency = best_result.latency + best_config = best_result.config + print(f"Best latency: {best_latency}") + print(f"Best config: {best_config}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=512, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=512, help="Matrix dimension K") + parser.add_argument("--tune", action="store_true", help="Enable tuning") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) + # main(M, N, K, True) diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/tilelang/original/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py new file mode 100644 index 0000000000000000000000000000000000000000..dea2e5ddd8a1e763cac46c1e07199b5c6077830c --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -0,0 +1,221 @@ +import tilelang +from tilelang import language as T +from typing import Optional, Callable, Any +import torch +from tilelang import DataType +from tilelang.quantize import ( + _tir_packed_int_to_int_convert, +) + + +@tilelang.jit +def dequantize_gemv( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + num_bits: int = 4, + storage_dtype: T.dtype = T.int8, + source_format: str = "uint", + n_partition: int = 4, + reduce_thread: int = 32, + fast_decoding: bool = False, + trans_A: bool = False, + trans_B: bool = True, + group_size: int = -1, + with_scaling: bool = False, +) -> Callable[..., Any]: + assert n_partition is not None, "n_partition must be provided" + assert reduce_thread is not None, ( + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) + + assert trans_A is False, "Dequantize is only implemented for trans_A=False currently" + assert trans_B is True, "Dequantize is only implemented for trans_B=True currently" + storage_type = "".join(c for c in storage_dtype if not c.isdigit()) + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + micro_size_k_compressed = micro_size_k // num_elems_per_byte + block_K = reduce_thread * micro_size_k + + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + C_shape = (M, N) + + dp4a_size = 4 + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + # Lazy import to decrease the startup time + # as intrin registry may take a while to load + from tilelang.quantize import get_lop3_intrin_group + + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + 
source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=with_scaling, + with_zeros=False, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = import_source + + @T.prim_func + def main( + A: T.Tensor[A_shape, in_dtype], + B: T.Tensor[B_shape, storage_dtype], + C: T.Tensor[C_shape, out_dtype], + ): + with T.Kernel( + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), + ) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") + + T.import_source(import_source) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] + + for v in T.vectorized(micro_size_k_compressed): + B_quant_local[v] = B[ + bx * n_partition + ni, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, + ] + + if fast_decoding: + T.call_extern( + func_name, + T.address_of(B_quant_local[0]), + T.address_of(B_dequantize_local[0]), + dtype=in_dtype, + ) + else: + for ki in T.serial(micro_size_k): + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype + ) + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_dequantize_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki] * B_dequantize_local[ki] + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return main + + +def main() -> None: + M = 1 + N = 1024 + K = 1024 + in_dtype = T.float16 + out_dtype = T.float16 + accum_dtype = T.float16 + num_bits = 4 + storage_dtype = T.int8 + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + # int4 reference + B = torch.zeros(qB.shape[0], 
qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for j in range(B.shape[1]): + B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("C: ", C) + print("Ref C: ", ref_c) + # doesn't apply scaling, the absolute error is large + torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/tilelang/original/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..9921c6bfe2dcc9265fe4875d21835d93baef6b90 --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -0,0 +1,522 @@ +import tilelang +import tilelang.language as T +from tilelang.quantize import _tir_u8_to_f4_to_bf16 +from tilelang import tvm as tvm +from tvm import DataType +import torch +from dequantize_utils import torch_convert_bit_twiddling, assert_similar +from tilelang.autotuner import set_autotune_inputs +import argparse + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes + - num_stages: pipeline stages + - threads: thread counts + - split: K-splitting factor + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. + """ + import itertools + + iter_params = dict( + block_M=[128], + block_N=[64, 128, 256], + block_K=[128], + num_stages=[0, 1, 2], + threads=[128, 256, 512], + split=[1], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[-1]) +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype` and shape (M, K). + - B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits). + - Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size). + - Bias: per-expert, per-output bias, shape (E, N). + - topk_weights: router weights for the top-k experts for each token, shape (M, topk). + - sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,). + - expert_ids: expert id for each token in the padded batch, shape (padding_M // block_M,). + - C: output tensor, shape (M, topk, N). 
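+
+    As a concrete illustration of the packing (a worked example using this file's defaults,
+    num_bits=4 and scale_size=32, with the K=2944 from the CLI defaults below):
+    num_elems_per_byte = 8 // 4 = 2, so each row of B holds QK = 2944 // 2 = 1472 packed
+    uint8 bytes, and each row of Scale holds K // 32 = 92 per-block exponents.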
+ + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split). + topk (int): number of experts selected per token. + E (int): number of experts. + padding_M (int): padded number of tokens after grouping and block alignment. + in_dtype (str): element type of A (e.g., T.bfloat16). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the grouped, pipelined GEMM that: + - loads tiled blocks of A and packed B for each expert to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - applies per-token topk weights and bias, + - writes the final (M, topk, N) block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. + """ + + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = block_N + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + # the dequant part is the same as in dequant_gemm + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. 
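+
+        Per-thread sizing, as a rough sketch for the bfloat16 path this factory supports:
+        with 128-bit transactions and 16-bit outputs, each intrinsic call expands
+        local_size = 128 // 16 = 8 values per thread out of
+        local_compress_size = local_size // num_elems_per_byte = 4 packed bytes.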
+ The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. + - Scale_shared: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. 
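+                # Addressing sketch for the loads below, stated in terms of this kernel's own
+                # constants: index_base is the flat byte offset of this thread's chunk within the
+                # packed (block_N, Block_QK) tile, so index // Block_QK and index % Block_QK
+                # recover its row and column. One Scale entry covers scale_size dequantized
+                # elements, i.e. scale_size // num_elems_per_byte packed bytes, which is why
+                # index_scale divides by that quantity; si and sj then give the tile row and the
+                # per-row scale column, offset by k * block_K // scale_size for the current K step.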
+ index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale_shared[ + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), T.int32), + expert_ids: T.Tensor((padding_M // block_M), T.int32), + C: T.Tensor((M, topk, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + topk_weights_shared = T.alloc_shared((block_M), out_dtype) + sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) + expert_id = T.alloc_local((1), T.int32) # the expert id for the current block + # To use 1D TMA, the last dim of Scale_shared must have stride=1 + # May use much more shared memory than necessary + Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + T.use_swizzle(10) + + if threads == 512: + 
T.disable_warp_group_reg_alloc() + + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) + expert_id[0] = expert_ids[by] + + # Get the topk weights of each token in the current block + for i in T.Parallel(block_M): + if sorted_token_ids_shared[i] != -1: + topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]] + + # Get bias and scale based on the expert id + if with_bias: + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) + else: + T.clear(Bias_shared) + + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) + + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = Bias_shared[j] + + tx = T.get_thread_binding() + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + # Each thread copies 4 bytes, local size is 16 + for copy_i in T.serial(block_M * block_K // threads // 16): + base = copy_i * threads * 16 + tx * 16 + if sorted_token_ids_shared[base // block_K] != -1: + for copy_j in T.vectorized(16): + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] + + T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = C_local[i, j] * topk_weights_shared[i] + + T.copy(C_local, C_shared) + for copy_i in T.serial(block_M * block_N // threads // 16): + base = copy_i * threads * 16 + tx * 16 + if sorted_token_ids_shared[base // block_N] != -1: + for copy_j in T.vectorized(16): + C[ + sorted_token_ids_shared[base // block_N] // topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] + + return main + + +def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): + dtypeC = T.bfloat16 + M, K = A.shape + E, N, QK = qB.shape + topk = topk_weights.shape[0] // M + scale_size = K // Scale.shape[2] + assert scale_size == 32 # MXFP4 + + # Initialize output tensor + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") + + # Iterate over sorted_token_ids + for idx in range(len(sorted_token_ids)): # padding_M + token_id = sorted_token_ids[idx] + if token_id == -1: + continue + expert_id = expert_ids[idx // block_M] + topk_idx = token_id % topk + + # Get the token embedding + token_embedding = A[token_id // topk] + + # Dequantize the expert weights + B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) + + # Compute the output for this token-expert pair + # token_embedding @ B.T + bias + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] + output = output.to(torch.__getattribute__(dtypeC)) + + # Apply the topk weight + weight = topk_weights[token_id] + output = output * weight + + # Store the result + C[token_id // topk, topk_idx] = output + + return C + + +def get_data(m, n, k, qk, scale_size, topk, E, block_M): + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, 
device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + # topk_weights: Router weights for the top-k experts for each token. + # Shape: (m, topk) + # tokens_experts: A flattened tensor of expert assignments for each token. + # For each of m tokens, topk unique experts are chosen. Shape: (m * topk,) + topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1) + tokens_experts = tokens_experts.reshape(m * topk) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.reshape(m * topk) + + sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True) + sorted_token_ids = sorted_indices + unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True) + expert_ids = [] + padded_token_ids = [] + start = 0 + for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()): + end = start + cnt + group_token_ids = sorted_token_ids[start:end] + pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt + if pad_len > 0: + # -1 for padding (`M` instead in vLLM moe_align_block_size()) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) + padded_token_ids.append(group_token_ids) + expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) + start = end + + # sorted_token_ids: The final flattened and padded tensor of token indices. + sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) + # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. 
+ expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) + padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding + + return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M + + +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): + # Tunable parameters + block_M, block_N, block_K = 128, 256, 128 # noqa: F841 + num_stages = 1 # noqa: F841 + threads = 512 # noqa: F841 + split = 1 # noqa: F841 + + total_flops = 2 * m * n * k * topk + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + # Autotune with inputs manually composed + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + output = kernel( + A, + qB, + Scale, + Bias, + topk_weights, + sorted_token_ids, + expert_ids, + ) + + print("Tilelang kernel run finished.") + + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... + + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + diff = (output - ref_output).abs() + max_val = diff.max() + max_idx = diff.argmax() + print(f"max abs diff: {max_val} at index: {max_idx}") + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference + print("All checks pass. 
✅") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--N", type=int, default=5760, help="N") + parser.add_argument("--K", type=int, default=2944, help="K") + parser.add_argument("--scale_size", type=int, default=32, help="scale size") + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--E", type=int, default=32, help="E") # number of experts + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/tilelang/original/examples/dequantize_gemm/test_example_dequantize_gemm.py b/tilelang/original/examples/dequantize_gemm/test_example_dequantize_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..01bc40e6c944b9c4b1f9fff36da132b002055a2b --- /dev/null +++ b/tilelang/original/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -0,0 +1,47 @@ +import tilelang.testing + +import example_dequant_gemv_fp16xint4 +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper_tma +import example_dequant_groupedgemm_bf16_mxfp4_hopper +import example_dequant_gemm_w4a8 + + +@tilelang.testing.requires_cuda +def test_example_dequant_gemv_fp16xint4(): + example_dequant_gemv_fp16xint4.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_fp4_hopper(): + example_dequant_gemm_fp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_bf16_mxfp4_hopper(): + example_dequant_gemm_bf16_mxfp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_bf16_mxfp4_hopper_tma(): + example_dequant_gemm_bf16_mxfp4_hopper_tma.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + example_dequant_groupedgemm_bf16_mxfp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_w4a8(): + example_dequant_gemm_w4a8.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/dsa_sparse_finetune/dsa.py b/tilelang/original/examples/dsa_sparse_finetune/dsa.py new file mode 100644 index 0000000000000000000000000000000000000000..9fae8e5e3d698c9d7763b707fa2b2fd7506257c2 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/dsa.py @@ -0,0 +1,223 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from indexer_topk_reducesum import indexer_topk_reducesum_interface +from indexer_bwd import indexer_bwd_interface +from sparse_mla_fwd import sparse_mla_fwd_interface +from sparse_mla_bwd import sparse_mla_bwd +from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface +from einops import einsum, repeat +from utils import get_abs_err, get_err_ratio + + +class RegsiterLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.save_for_backward(loss) + return x + + @staticmethod + def 
backward(ctx, grad): + loss = ctx.saved_tensors + return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device) + + +register_loss = RegsiterLossFunction.apply + + +def ref_deepseek_sparse_attention_innner( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + dtype = q.dtype + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) + + index_sm_scale = index_q.shape[-1] ** -0.5 + b, s = index_q.shape[:2] + + # tl_topk_indices = tl_topk_indices.to(torch.int64) + # tl_topk_indices[tl_topk_indices == -1] = s + + casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") + index_logits = F.relu(index_logits) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) + topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices + topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices) + topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) + index_topk_score = topk_score + + if sm_scale is None: + sm_scale = kv.shape[-1] ** -0.5 + + h = q.shape[-2] + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) + k, v = kv, kv[..., :dim_v] + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") + + attn_score = attn_score.sum(dim=-2) # [b, s1, s2] + attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) + attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) + + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") + o = register_loss(o, loss) + + return o.to(dtype), topk_indices + + +def ref_deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + all_o, all_topk_indices = [], [] + for i in range(offsets.shape[0] - 1): + o, topk_indices = ref_deepseek_sparse_attention_innner( + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], + topk, + dim_v, + sm_scale, + index_sm_scale, + ) + all_o.append(o.squeeze(0)) + all_topk_indices.append(topk_indices.squeeze(0)) + o = torch.cat(all_o, dim=0) + topk_indices = torch.cat(all_topk_indices, dim=0) + return o, topk_indices + + +class DSAFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + 
sm_scale: Optional[float] = None, + ): + # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) + ctx.topk = topk + ctx.dim_v = dim_v + ctx.sm_scale = sm_scale + return o, topk_indices + + @staticmethod + def backward( + ctx, + do: torch.Tensor, + _1: torch.Tensor, + ): + q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors + attn_score = sparse_mla_topk_reducesum_interface( + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) + return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None + + +def deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, +): + return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale) + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + index_D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_() + index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_() + weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_() + index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_() + do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_() + offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda() + + o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + o.backward(do) + q_grad, q.grad = q.grad, None + kv_grad, kv.grad = kv.grad, None + index_q_grad, index_q.grad = index_q.grad, None + index_k_grad, index_k.grad = index_k.grad, None + weights_grad, weights.grad = weights.grad, None + + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + ref_o.backward(do) + ref_q_grad, q.grad = q.grad, None + ref_kv_grad, kv.grad = kv.grad, None + ref_index_q_grad, index_q.grad = index_q.grad, None + ref_index_k_grad, index_k.grad = index_k.grad, None + ref_weights_grad, weights.grad = weights.grad, None + + print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") + print( + f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" + ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, 
ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") + + intersections = [] + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + mask = trt_np != -1 + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + intersections.append(len(intersection) / len(set_ref)) + print("average intersections: {:.4f}".format(sum(intersections) / len(intersections))) + + +test_kernel() diff --git a/tilelang/original/examples/dsa_sparse_finetune/index.py b/tilelang/original/examples/dsa_sparse_finetune/index.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4800411004e5890faba0578cf83f09e27f2dc9 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/index.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) diff --git a/tilelang/original/examples/dsa_sparse_finetune/indexer_bwd.py 
b/tilelang/original/examples/dsa_sparse_finetune/indexer_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..68508ad4e45104b3b5717c95ef30ebfe1caaccd4 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/indexer_bwd.py @@ -0,0 +1,254 @@ +import torch +import torch.nn.functional as F +from einops import einsum, repeat + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [seq_len, topk] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos = Offsets[i_b] + num_blocks = T.ceildiv(topk, block_I) + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + weights_shared = T.alloc_shared([heads], dtype=dtype) + + d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype) + + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.copy(Weights[bos + i_t, :], weights_shared) + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + + for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + i_st = bi_i * block_I + i_ed = (bi_i + 1) * block_I + + indices_shared = T.alloc_shared([block_I], dtype=INT32) + T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared) + + index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) + + attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + for i in T.Parallel(block_I): + attn_score_shared[i] = AttnScore[bos + i_t, i_st + i] + index_score_shared[i] = IndexScore[bos + i_t, i_st + i] + + logits = T.alloc_fragment((block_I, heads), accum_dtype) + T.gemm( + index_k_shared, + index_q_shared, + 
logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + # dw + d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) + d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype) + d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype) + + for i, j in T.Parallel(block_I, heads): + d_relu = T.alloc_var(accum_dtype) + if logits[i, j] > 0: + d_relu = 1.0 + else: + d_relu = 0.0 + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] + + # dq + T.copy(d_logits_qk, d_logits_qk_cast1) + T.gemm( + d_logits_qk_cast1, # [BS, HQ] + index_k_shared, # [BS, K] + d_index_q_frag, # [HQ, K] + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + # dk + T.copy(d_logits_qk, d_logits_qk_cast2) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + T.gemm( + d_logits_qk_cast2, # [BS, HQ] + index_q_shared, # [HQ, K] + d_index_k_frag, # [BS, K] + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + if (pos > -1) & (pos <= i_t): + T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) + + for i, j in T.Parallel(heads, dim): + d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale + + T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :]) + T.copy(d_weights_frag, dWeights[bos + i_t, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + attn_score: torch.Tensor, + index_score: torch.Tensor, + topk_indices: torch.Tensor, + offsets: torch.Tensor, +): + _, heads, dim, topk = *q.shape, topk_indices.shape[-1] + token_indices = prepare_token_indices(offsets) + dq = torch.zeros_like(q) + dweights = torch.zeros_like(weights) + dk = torch.zeros_like(k) + kernel = tl_indexer_bwd_impl(heads, dim, topk) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) + return dq, dweights, dk + + +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: + Q.requires_grad_(True) + Weights.requires_grad_(True) + K.requires_grad_(True) + softmax_scale = Q.shape[-1] ** -0.5 + all_loss = [] + all_log_topk_prob = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale + logits = F.relu(logits) + score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) + score = torch.where(mask, score, float("-inf")) + topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) + log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 
0), log_target=True, reduction="sum") + all_loss.append(loss) + all_log_topk_prob.append(log_topk_prob) + loss = torch.stack(all_loss).sum() + loss.backward() + log_topk_prob = torch.cat(all_log_topk_prob, dim=0) + return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad + + +def test_kernel( + B=1, + S=2048, + H=16, + D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D)).cuda().bfloat16() + w = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + all_attn_score = [] + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) + logits = torch.ones(seq_len, topk).cuda() + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + all_attn_score.append(attn_score) + attn_score = torch.cat(all_attn_score, dim=0) + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets) + + dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) + + print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}") + print(f"dw err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}") + print(f"dk err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/tilelang/original/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/tilelang/original/examples/dsa_sparse_finetune/indexer_topk_reducesum.py new file mode 100644 index 0000000000000000000000000000000000000000..d76eb027247b9ce8fdf4cd20f422d7a79304eb3b --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -0,0 +1,273 @@ +import math +import torch +import torch.nn.functional as F +from einops import einsum + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_topk_reducesum_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_K: int = 32, + dtype: str = FP32, + num_stages: int = 0, + num_threads: int = 128, +): + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_K == 0 + assert heads <= 64 and heads % 8 == 0 + assert num_stages == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + N = 2 * topk + num_iters = int(round(math.log2(N))) + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.macro + def bitonic_sort( + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), + ): + T.sync_threads() + for i1 in T.serial(num_iters): + for i2 in T.serial(i1 + 1): + for i in T.Parallel(N): + 
ascending = (i & (1 << (i1 + 1))) != 0 + j = i ^ (1 << (i1 - i2)) + if i < j and ( + (ascending and topk_value_shared[i] > topk_value_shared[j]) + or (not ascending and topk_value_shared[i] < topk_value_shared[j]) + ): + val = topk_value_shared[i] + topk_value_shared[i] = topk_value_shared[j] + topk_value_shared[j] = val + idx = topk_index_shared[i] + topk_index_shared[i] = topk_index_shared[j] + topk_index_shared[j] = idx + T.sync_threads() + + @T.prim_func + def tl_indexer_topk_reducesum_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos, eos = Offsets[i_b], Offsets[i_b + 1] + num_blocks = T.ceildiv(i_t + 1, block_K) + + topk_index_shared = T.alloc_shared([N], dtype=INT32) + topk_value_shared = T.alloc_shared([N], dtype=FP32) + + T.fill(topk_index_shared, -1) + T.fill(topk_value_shared, float("-inf")) + T.sync_threads() + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.sync_threads() + + weights_frag = T.alloc_shared([heads], dtype=dtype) + T.copy(Weights[bos + i_t, :], weights_frag) + T.sync_threads() + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + T.sync_threads() + + for bk_i in T.Pipelined(num_blocks, num_stages=num_stages): + k_st = bk_i * block_K + k_ed = T.min((bk_i + 1) * block_K, eos - bos) + + index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) + for i, j in T.Parallel(block_K, dim): + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) + T.sync_threads() + + logits = T.alloc_fragment((block_K, heads), FP32) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + T.sync_threads() + + for i, j in T.Parallel(block_K, heads): + logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j] + T.sync_threads() + + logits_sum = T.alloc_fragment(block_K, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + T.sync_threads() + + offset = T.alloc_var(INT32) + if k_st >= topk: + offset = topk + (k_st % topk) + else: + offset = k_st + T.sync_threads() + for i in T.Parallel(block_K): + if k_st + i > i_t: + logits_sum[i] = float("-inf") + j = offset + i + topk_index_shared[j] = k_st + i + topk_value_shared[j] = logits_sum[i] + T.sync_threads() + + if k_ed > topk and k_ed % topk == 0: + bitonic_sort(topk_index_shared, topk_value_shared) + + bitonic_sort(topk_index_shared, topk_value_shared) + + logits_max_frag = T.alloc_fragment([1], dtype=FP32) + logits_frag = T.alloc_fragment([topk], dtype=FP32) + reducesum_shared = T.alloc_shared([topk], dtype=FP32) + + T.copy(topk_value_shared[:topk], logits_frag) + T.sync_threads() + + T.reduce_max(logits_frag, logits_max_frag, dim=-1) + T.sync_threads() + + for i in T.Parallel(topk): + logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0]) + T.sync_threads() + + lse_frag = T.alloc_fragment([1], dtype=FP32) + T.reduce_sum(logits_frag, lse_frag) + T.sync_threads() + + for i in T.Parallel(topk): + reducesum_shared[i] = logits_frag[i] / lse_frag[0] + T.sync_threads() + + # for i in T.Parallel(topk): + # reducesum_shared[i] = logits_frag[i] + # 
T.sync_threads() + + for i in T.Parallel(topk): + if topk_index_shared[i] > i_t: + topk_index_shared[i] = -1 + T.sync_threads() + + T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :]) + T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :]) + + return tl_indexer_topk_reducesum_kernel + + +def indexer_topk_reducesum_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + topk: int, + offsets: torch.Tensor, + dtype: str = BF16, +): + seq_len, heads, dim = q.shape + kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype) + token_indices = prepare_token_indices(offsets) + topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32) + topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32) + kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices) + return topk_indices, topk_score + + +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: + all_topk_indices = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= topk + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") + logits = F.relu(logits) + logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale + logits = torch.where(mask, logits, float("-inf")) + topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) + topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + all_topk_indices.append(topk_indices) + all_topk_score.append(topk_score) + topk_indices = torch.cat(all_topk_indices, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return topk_indices, topk_score + + +def test_kernel( + B=1, + S=2048, + H=64, + D=128, + topk=64, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D)).cuda().bfloat16() + weights = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, S], dtype=torch.int32).cuda() + + ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets) + + topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets) + + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + ref_np_val = ref_topk_score[j] + trt_np_val = topk_score[j] + + mask = (ref_np_val > 0).cpu().numpy() + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8b76dbca1c5fa483f57399e701a17bff870edd80 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -0,0 +1,354 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from index import 
prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + S = T.symbolic("S") + + shape = [S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + S_kv = T.symbolic("S_kv") + + dkv_shape = [S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): + T.copy( + dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def bwd( + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") + + H_kv = H // kv_group + q_shape = [S, H, D + D_tail] + k_shape = [S, kv_group, D + D_tail] + o_shape = [S, H, D] + indices_shape = [S, kv_group, topk] + delta_shape = [S, H] + lse_shape = [S, H] + offsets_shape = [B_plus_one] + token_indices_shape = [S, 2] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: 
T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + + max_kv_i = s_i + + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + 
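+                # acc_dkv folds in both gradient paths for the first D dims of KV, which act as key and
+                # value simultaneously: dKV[:, :D] += dP^T @ Q + P^T @ dO. The tail dims are key-only,
+                # so they receive just the dP^T @ Q_tail term below.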
T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + S, H, dim_plus_tail_dim = q.shape + S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert S == S_kv + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (S, kv_group, topk) + assert lse.shape == (S, H) + + token_indices = prepare_token_indices(offsets) + + # Get kernels + preprocess_kernel = preprocess(H, D) + bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") + offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) + + 
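+    # Backward: run the tilelang kernel and the autograd reference on the same inputs.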
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d87523695240ce3029c29e84c10c50cbfc4a39c8 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -0,0 +1,310 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + else: + sm_scale = sm_scale + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + head_kv = heads // kv_group + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len, kv_group, dim + tail_dim] + o_shape = [seq_len, heads, dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, 
dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? 
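+                # sumexp_i is meant to hold just this block's row sums; the running normalizer sumexp is
+                # folded together below via sumexp = sumexp * alpha + sumexp_i (the online-softmax update).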
+ for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[bos + s_i, H0:H1, :]) + T.copy(sumexp, Lse[bos + s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + seq_len, heads, dim_plus_tail_dim = q.shape + seq_len_kv, kv_group, _ = kv.shape + assert seq_len == seq_len_kv + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + _, _, topk = indices.shape + assert indices.shape == (seq_len, kv_group, topk) + + token_indices = prepare_token_indices(offsets) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices, offsets, token_indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True): + Q = Q.float() + KV = KV.float() + all_o = [] + for i in range(offsets.shape[0] - 1): + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() + + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) + + indices[indices > sk] = sk + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + all_o.append(o.squeeze(0)) + o = torch.cat(all_o, dim=0) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, 
device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=1024, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py new file mode 100644 index 0000000000000000000000000000000000000000..a03bc74f51e254b8cd9232eebc91bc9c6f0fa4c9 --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -0,0 +1,226 @@ +# ruff: noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import repeat, rearrange, einsum +from index import prepare_token_indices +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tilelang.jit(pass_configs=pass_configs) +def tl_sparse_mla_topk_reducesum_impl( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = heads // kv_group + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + 
REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len_kv, kv_group, dim + tail_dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + @T.prim_func + def tl_sparse_mla_topk_reducesum_kernel( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + reducesum = T.alloc_fragment([BI], accum_dtype) + lse = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(lse, 0) + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + r_i = bx % REPLICATE_H + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + T.copy(Lse[bos + s_i, H0:H1], lse) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) + T.reduce_sum(acc_s, reducesum, dim=0) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI]) + + return tl_sparse_mla_topk_reducesum_kernel + + +def sparse_mla_topk_reducesum_interface( + q: torch.Tensor, + kv: torch.Tensor, + topk_indices: torch.Tensor, + lse: torch.Tensor, + offsets: torch.Tensor, + dim_v: int, +): + assert kv.shape[-2] == 1 + seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1] + REPLICATE_H = max(heads // 64, 1) + tail_dim = dim_plus_tail_dim - dim_v + token_indices = prepare_token_indices(offsets) + + reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device) + kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, 
dim=dim_v, tail_dim=tail_dim, topk=topk) + kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum) + reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk] + attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True) + + return attn_score + + +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): + # q: [batch, seq_len, heads, dim] + # k: [batch, seq_len, dim] + sm_scale = Q.shape[-1] ** -0.5 + all_lse = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + seq_len = q.shape[0] + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + score = F.softmax(logits, dim=-1, dtype=torch.float32) + score_sum = score.sum(dim=-2) + topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) + topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) + max_logits = logits.amax(dim=-1).to(torch.float32) + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + all_lse.append(lse) + all_topk_score.append(topk_score) + lse = torch.cat(all_lse, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return lse, topk_score + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + topk=128, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + + lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) + + kv = kv.unsqueeze(-2) + topk_indices = topk_indices.unsqueeze(-2) + + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/tilelang/original/examples/dsa_sparse_finetune/utils.py b/tilelang/original/examples/dsa_sparse_finetune/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96afd064dc0f83f0e813fa4093f10d2fd309dfce --- /dev/null +++ b/tilelang/original/examples/dsa_sparse_finetune/utils.py @@ -0,0 +1,73 @@ +import torch + + +def get_abs_err(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + return (x - y).flatten().abs().max().item() + + +def get_err_ratio(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + err = (x - y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + + +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. 
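+    In code terms: sim = 2 * (x * y).sum() / ((x * x).sum() + (y * y).sum()).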
It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 diff --git a/tilelang/original/examples/elementwise/example_elementwise_add.py b/tilelang/original/examples/elementwise/example_elementwise_add.py new file mode 100644 index 0000000000000000000000000000000000000000..f075c64fd669ff4ce15ae518b3e17998aad8edae --- /dev/null +++ b/tilelang/original/examples/elementwise/example_elementwise_add.py @@ -0,0 +1,62 @@ +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T + + +def ref_program(x, y): + return x + y + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[-1]) +def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): + @T.prim_func + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), in_dtype) + B_shared = T.alloc_shared((block_M, block_N), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + for local_y, local_x in T.Parallel(block_M, block_N): + C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, 
bx * block_N]) + + return elem_add + + +def main(M=1024, N=1024, use_autotune=False): + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if use_autotune: + kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32) + else: + # Default config + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) + + out = kernel(a, b) + torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + main(args.m, args.n, args.use_autotune) diff --git a/tilelang/original/examples/elementwise/test_example_elementwise.py b/tilelang/original/examples/elementwise/test_example_elementwise.py new file mode 100644 index 0000000000000000000000000000000000000000..24f675cd6a3778280ce1a52c1b6e6ca54aa8393c --- /dev/null +++ b/tilelang/original/examples/elementwise/test_example_elementwise.py @@ -0,0 +1,14 @@ +import tilelang.testing +import example_elementwise_add + + +def test_example_elementwise_add(): + example_elementwise_add.main() + + +def test_example_elementwise_add_autotune(): + example_elementwise_add.main(use_autotune=True) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/flash_attention/README.md b/tilelang/original/examples/flash_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..633727ec4e9270b66176db82d3e13f430895c33a --- /dev/null +++ b/tilelang/original/examples/flash_attention/README.md @@ -0,0 +1,111 @@ +# FlashAttention + +Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way. 
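+ For reference, the per-block update in the listing below is the standard online-softmax recurrence. A minimal NumPy sketch (illustrative only; it uses natural `exp`, whereas the kernel pre-scales the scores and uses `T.exp2`) looks like this:
+
+```python
+import numpy as np
+
+def online_softmax_step(o, l, m, S_blk, V_blk):
+    # o: running output, l: running normalizer (logsum), m: running row max
+    m_new = np.maximum(m, S_blk.max(axis=1))   # updated per-row maximum
+    alpha = np.exp(m - m_new)                  # rescale factor for previous partials
+    p = np.exp(S_blk - m_new[:, None])         # unnormalized block probabilities
+    l_new = l * alpha + p.sum(axis=1)          # running normalizer
+    o_new = o * alpha[:, None] + p @ V_blk     # running partial output
+    return o_new, l_new, m_new
+```
+
+ The kernel keeps `acc_o`, `logsum` and `scores_max` in registers and performs the final division by `logsum` only once, after the loop over key/value blocks.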
+ +```python +@T.prim_func +def flash_attention( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), +): + # Launch a specialized T.Kernel with 3D mapping: (bx, by, bz) + # bx: block index in sequence dimension + # by: block index in "heads" dimension + # bz: block index in "batch" dimension + # threads=thread_num means how many threads per block + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + # Allocate shared memory for Q, K, V to reduce global memory accesses + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + # Allocate buffers on register + # acc_s: buffer to hold intermediate attention scores + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + # acc_s_cast: buffer for storing casted/adjusted scores + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + # acc_o: partial accumulation of output + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + # Buffers to track per-row maximum score and related stats + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + # Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + + # Copy a block of Q from global memory to Q_shared + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + + # Initialize accumulators + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = ( + T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + ) + + # Pipeline the loop to overlap copies/gemm stages + for k in T.Pipelined(loop_range, num_stages=num_stages): + # Copy K block into shared memory + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) + + # Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed, + # policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row + # of acc_s, and the resulting acc_s is retained in registers. 
+ T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Copy V block into shared memory + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale + + # Save old scores_max, then reset scores_max + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # Compute the maximum value per row on dimension 1 (block_N) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Compute the factor by which we need to rescale previous partial sums + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + + # Rescale the partial output accumulation to keep exponents consistent + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + # Exponentiate (scores - max) for the new block + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + + # Make a cast of acc_s to fp16 for the next GEMM + T.copy(acc_s, acc_s_cast) + + # Multiply the attention acc_s_cast by V and add to partial output (acc_o) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + # Update the "logsum" tracker with the newly accumulated sum + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + # Final step: divide each partial output by logsum (completing the softmax) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + + # Write back the final output block from acc_o to the Output buffer + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) +``` \ No newline at end of file diff --git a/tilelang/original/examples/flash_attention/bert_padding.py b/tilelang/original/examples/flash_attention/bert_padding.py new file mode 100644 index 0000000000000000000000000000000000000000..15c4097ce77a21ebcd2060b53c629e7a89972b88 --- /dev/null +++ b/tilelang/original/examples/flash_attention/bert_padding.py @@ -0,0 +1,205 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py +# ruff: noqa +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
+ # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. 
Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): + """ + Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). + The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). + + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + length = attention_mask_in_length.sum(dim=-1) + seqlen = attention_mask_in_length.size(-1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) + real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() + seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] + indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) 
+ """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) diff --git a/tilelang/original/examples/flash_attention/example_gqa_bwd.py b/tilelang/original/examples/flash_attention/example_gqa_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..89c1166693c672a8fd0021419837c950a0651df9 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_bwd.py @@ -0,0 +1,514 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= 
scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_qk] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: 
T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dk_shape = [groups, 
batch, seq_len, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + T.copy(dk, dk_shared) + 
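+ # The split variant trades atomics for memory on dK/dV: each query-head group writes its
+ # partial gradients into its own slice of the [groups, ...] buffers declared above, and the
+ # host sums over the group dimension afterwards (dk.sum(0) / dv.sum(0) in the autograd
+ # backward). Only dQ is still accumulated with per-element atomic_add in this kernel.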
T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK) + delta = mod_prep(o, do) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + kernel = flashattn_bwd_split( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + 
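+ # use_atomic=True selects the atomic-add backward (fp32 dK/dV accumulated in place, then
+ # cast to fp16); use_atomic=False selects the split backward (per-group fp16 dK/dV buffers
+ # that are summed after the kernel).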
use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..07586f99fdd7d55d52a59359337c423dcca96a6f --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -0,0 +1,535 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.contrib import nvcc +import argparse + +tilelang.disable_cache() + 
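+# This file mirrors the GQA backward example above, but dQ/dK/dV are accumulated with
+# TMA-based reductions (T.atomic_add(..., use_tma=True) on whole shared-memory tiles) instead
+# of per-element atomic adds, which is why __main__ asserts compute capability >= 9.0 (Hopper).
+#
+# As in the other examples, log2(e) is folded into the softmax scale so exp/log can be replaced
+# by exp2/log2. A quick sanity check of that identity (a sketch only, assuming NumPy is
+# available; it is not used by the example itself):
+#
+#   import numpy as np
+#   x = np.random.randn(16).astype(np.float32)
+#   scale = (1.0 / 192) ** 0.5        # 1/sqrt(dim_qk)
+#   log2e = 1.44269504
+#   ref = np.exp((x - x.max()) * scale)
+#   alt = np.exp2(x * scale * log2e - x.max() * scale * log2e)
+#   assert np.allclose(ref, alt, rtol=1e-5)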
+ +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, 
lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # bshd -> bhld to use tma reduction instruction + return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d]) + + +@tilelang.jit( + out_idx=[3, 4, 5], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) + with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: 
T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // 
groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, 
j], dq[i, j]) + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + T.copy(dk, dk_shared) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) + delta = mod_prep(o, do) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq, dk, dv = mod_post(dq, dk, dv) + else: + kernel = flashattn_bwd_split_novarlen( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = 
torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + arch = nvcc.get_target_compute_version() + print(f"Detected GPU compute capability: {arch}") + assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py 
b/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..cc88b64da7a44ffc9b95f09bb8a1cd45eb681136 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -0,0 +1,730 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.contrib import nvcc +import argparse +from einops import rearrange, repeat +from bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +@tilelang.jit( + out_idx=[5, 6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + o_shape = [total_q, heads, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + + for i, d in T.Parallel(block_M, dim_qk): + if bx * block_M + i < q_current_seqlen: + Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d] + else: + Q_shared[i, d] = 0.0 + + T.fill(acc_o, 0.0) + T.fill(logsum, 0.0) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + 
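+ # Varlen layout: Q/K/V are packed along the token axis, and cu_seqlens_q / cu_seqlens_k give
+ # each sequence's offset into that packed axis. This block therefore only walks its own
+ # sequence's keys (loop_range below is T.ceildiv(k_current_seqlen, block_N)), and
+ # out-of-range rows/columns are masked with -1e30 rather than -inf (see the warning above).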
loop_range = T.ceildiv(k_current_seqlen, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + for i, d in T.Parallel(block_N, dim_qk): + if k * block_N + i < k_current_seqlen: + K_shared[i, d] = K[k_start_idx + k * block_N + i, by // groups, d] + else: + K_shared[i, d] = 0.0 + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.Cast(accum_dtype, -1e30), + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + ) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, d in T.Parallel(block_N, dim_v): + if k * block_N + i < k_current_seqlen: + V_shared[i, d] = V[k_start_idx + k * block_N + i, by // groups, d] + else: + V_shared[i, d] = 0.0 + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + + for i, d in T.Parallel(block_M, dim_v): + if bx * block_M + i < q_current_seqlen: + Output[q_start_idx + bx * block_M + i, by, d] = acc_o[i, d] + + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + if bx * block_M + i < q_current_seqlen: + lse[bz, by, bx * block_M + i] = logsum[i] + + return flash_fwd + + +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [total_q, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + for i, j in T.Parallel(blk, blk): + if by * blk + i < q_current_seqlen and k * blk + j < dim_v: + o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j] + do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j] + else: + o[i, j] = 0.0 + do[i, j] = 0.0 + + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + + for i in T.Parallel(blk): + if by * blk + i < q_current_seqlen: + Delta[bz, 
bx, by * blk + i] = delta[i] + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # bshd -> bhsd to use tma reduction instruction + return T.Layout(dQ.shape, lambda l, h, d: [h, l, d]) + + +@tilelang.jit( + out_idx=[3, 4, 5], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) + with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + do_shape = [total_q, heads, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = 
T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + + T.clear(dv) + T.clear(dk) + + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 + loop_ed = T.ceildiv(q_current_seqlen, block_N) + + for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) + + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) + + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) + T.clear(dsT) + # dsT: (block_kv, block_q) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.copy(dq, dq_shared) + T.atomic_add( + dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], + dq_shared, + memory_order="relaxed", + use_tma=True, + ) + + T.copy(dv, dv_shared) + T.atomic_add( + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dv_shared, + memory_order="relaxed", + use_tma=True, + ) + T.copy(dk, dk_shared) + T.atomic_add( + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dk_shared, + memory_order="relaxed", + use_tma=True, + ) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + 
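+ # Packed varlen shapes: total_q / total_kv count the tokens of all sequences after
+ # unpad_input, and cu_seqlens_q / cu_seqlens_k mark each sequence's offset into that
+ # packed token axis.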
k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + do_shape = [total_q, heads, dim_v] + dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + + T.clear(dv) + T.clear(dk) + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 + loop_ed = T.ceildiv(q_current_seqlen, block_N) + + for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + # Note: The padding zero of varlen should be considered in T.copy + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) + + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) + + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = 
T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + if k_base * block_N + i < q_current_seqlen: + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) + T.copy(dk, dk_shared) + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + + total_q = q_unpad.shape[0] + total_kv = k_unpad.shape[0] + + mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) + o = pad_input(o_unpad, indices_q, BATCH, N_CTX) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) + ctx.batch = BATCH + ctx.causal = causal + ctx.use_atomic = use_atomic + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.indices_q = indices_q + ctx.indices_k = indices_k + return o + + @staticmethod + def backward(ctx, do): + N_CTX = do.shape[1] + q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + # lse_clone = lse.clone() + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + total_q, H, D_HEAD_QK = q.shape + total_kv, HEAD_KV, D_HEAD_V = v.shape + groups = H // HEAD_KV + BATCH = len(cu_seqlens_q) - 1 + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) + delta = mod_prep(o, do, cu_seqlens_q) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, + total_q, + total_kv, + N_CTX, + 
H, + ctx.max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups, + ) + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.zeros_like(k, dtype=torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq, dk, dv = mod_post(dq, dk, dv) + else: + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + ctx.max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups, + ) + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) + dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) + dk, dv = dk.sum(0), dv.sum(0) + + dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) + dk = pad_input(dk, ctx.indices_k, BATCH, N_CTX) + dv = pad_input(dv, ctx.indices_k, BATCH, N_CTX) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, padding_mask, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + # To handle precision issue + Q, K, V = Q.float(), K.float(), V.float() + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if padding_mask is not None: + scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + if padding_mask is not None: + output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") + seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) + cu_seqlens_q = 
F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) + max_seqlen_q = seqlens_q.max().item() + + # In training backward pass, seqlens_k should be the same as seqlens_q + seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q + + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, padding_mask, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + print( + "Note: this varlen kernel performance is as good as the non-varlen kernel shown in Nsight-Compute. As you may observe that the TFLOPS is a bit lower, that's because the unpad operation is included in the above benchmark." + ) + + +if __name__ == "__main__": + arch = nvcc.get_target_compute_version() + print(f"Detected GPU compute capability: {arch}") + assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + # Can be set to True/False for testing + args.causal = True + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/tilelang/original/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/tilelang/original/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e2de27752cee99c77fab091b599a4de5e65928 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -0,0 +1,353 @@ +import 
torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return 
flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) 
* block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.wait_wgmma(1) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.wait_wgmma(0) + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + delta = mod_prep(o, do) + + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = dq.to(torch.float16) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * 
groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git 
a/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd.py b/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..5005435eaf7cb2349f09d8900a4718e9f84f52e6 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd.py @@ -0,0 +1,256 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +class FlashAttentionTuneSpace: + def __init__( + self, + block_sizes=(64, 128, 256), + thread_options=(128, 256, 512), + num_stages_range=(2, 3), + max_shared_mem=100 * 1024, + warp_alignment=16, + dim=128, + dtype_bytes=2, + ): + self.block_sizes = block_sizes + self.thread_options = thread_options + self.num_stages_range = num_stages_range + self.max_shared_mem = max_shared_mem + self.warp_alignment = warp_alignment + self.dim = dim + self.dtype_bytes = dtype_bytes + + +def get_configs(user_config=None): + config = user_config or FlashAttentionTuneSpace() + valid_configs = [] + + for block_M, block_N in itertools.product(config.block_sizes, repeat=2): + for threads in config.thread_options: + assert threads % 32 == 0 + warp_count = threads // 32 + warp_M = block_M // warp_count + warp_N = block_N // warp_count + + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: + continue + + shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) + if shared_mem > config.max_shared_mem: + continue + + for num_stages in config.num_stages_range: + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) + return valid_configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: 
T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) 
* groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7a71b1780ba0ea370d01f5f72379e59491f76e --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -0,0 +1,243 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict( + block_M=[128], + block_N=[128], + num_stages=[2], + 
threads=[256], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_configs(), + warmup=10, + rep=10, +) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=64, + block_N=64, + num_stages=0, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
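+ # Added note (illustrative numbers only, not part of the algorithm): with a raw score x = 1.0 and row max m = 2.0, + # exp(1.0 - 2.0) ≈ 0.3679 and exp2(1.0 * 1.44269504 - 2.0 * 1.44269504) = exp2(-1.44269504) ≈ 0.3679, so + # exp(x - m) == exp2((x - m) * log2(e)); in this kernel log2(e) is folded into `scale` together with the 1/sqrt(dim) factor.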
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, 
seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_gqa_fwd_varlen.py b/tilelang/original/examples/flash_attention/example_gqa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..b02345d93084843074d0924a4e945424bf104ca7 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_gqa_fwd_varlen.py @@ -0,0 +1,253 @@ +# ruff: noqa +import argparse +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from einops import rearrange, repeat +from tilelang.profiler import do_bench +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), + upcast=True, +): + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + b, T, Hq, D = q.shape + S = k.shape[1] + scale = (1.0 / D) ** 0.5 + k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) + scores = torch.einsum("bthd,bshd->bhts", q, k) + left, right = window_size + left = S if left is None or left < 0 else int(left) + right = S if right is None or right < 0 else int(right) + t_idx = torch.arange(T, device=scores.device)[:, None] + s_idx = torch.arange(S, device=scores.device)[None, :] + visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right)) + visible_mask = visible_ts.unsqueeze(0).unsqueeze(0) + if key_padding_mask is not None: + k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s") + visible_mask = visible_mask & k_keep + neg_inf = torch.finfo(scores.dtype).min + scores = scores * scale + scores = scores.masked_fill(~visible_mask, neg_inf) + attention = 
torch.softmax(scores, dim=-1).to(v.dtype) + if query_padding_mask is not None: + q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1") + attention = attention.masked_fill(~q_keep, 0.0) + output = torch.einsum("bhts,bshd->bthd", attention, v) + if query_padding_mask is not None: + output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + 
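# Added note: scores_max_prev saved above keeps the running row max from earlier KV blocks; the block below recomputes + # the max over the new scores, and scores_scale then rescales acc_o and logsum so earlier contributions stay + # consistent with the updated max (standard online-softmax bookkeeping). +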
T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + if is_causal: + total_flops *= 0.5 + + tilelang.testing.set_random_seed(0) + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + out = output_pad_fn(out_unpad) + + out_ref, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + causal=is_causal, + ) + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + 
parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") + args = parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/tilelang/original/examples/flash_attention/example_mha_bwd_bhsd.py b/tilelang/original/examples/flash_attention/example_mha_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..835a315965db00752f4096a60a2de4a4db10bf68 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_bwd_bhsd.py @@ -0,0 +1,363 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + # Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + # T.copy(Q_shared, Q_local) + # for i, j in T.Parallel(block_M, dim): + # Q_local[i, j] *= scale + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in 
T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), 
batch, threads=128) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
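+ # Added note on the math in this tile: qkT now holds P^T (the softmax probabilities, transposed), dsT computed below is + # dP^T = (dO @ V^T)^T, and dsT_cast = P^T * (dP^T - delta) * sm_scale plays the role of dS^T, which feeds the + # dK accumulation (dS^T @ Q) and, via the transposed gemm, the dQ accumulation (dS @ K).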
+ T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, H, N_CTX, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, H, N_CTX, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(2) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert 
torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + + print("All checks passed.✅") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/tilelang/original/examples/flash_attention/example_mha_bwd_bshd.py b/tilelang/original/examples/flash_attention/example_mha_bwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..c0620bde0e95480d907fe94d9909ee3c30348860 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_bwd_bshd.py @@ -0,0 +1,354 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + # Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in 
T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } 
+) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
+ T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, N_CTX, H, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert 
torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/tilelang/original/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/tilelang/original/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..34a8d69ce475f7126e2635ae91f9a71159aa0d91 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -0,0 +1,331 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, 
:], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): + K_shared = 
T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + dq_shared = T.alloc_shared([block_N, dim], accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.wait_wgmma(1) + + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
+ T.wait_wgmma(0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, N_CTX, H, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + delta = mod_prep(o, do) + mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + mod(q, k, v, do, lse, delta, dq, dk, dv) + dq = dq.to(torch.float16) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, 
rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd.py b/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..e70d17bf8c9adc9b12d90111e3fe8906cccc5ba0 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd.py @@ -0,0 +1,220 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + 
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, 
:]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c4d81ece8607c4499798d58421c067dc60b518 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,224 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + 
print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_mha_fwd_bshd.py b/tilelang/original/examples/flash_attention/example_mha_fwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..248073f797d7b41d6b223f6de25c02f5b486de5b --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_fwd_bshd.py @@ -0,0 +1,205 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): 
+ T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, 
bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/tilelang/original/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2aab44f496f7ae422d17f5d2a1baf0c057116e --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -0,0 +1,211 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, 
block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} 
ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/tilelang/original/examples/flash_attention/example_mha_fwd_varlen.py b/tilelang/original/examples/flash_attention/example_mha_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba2e8ab472d371cedcf1fe2d7203e25c85ec099 --- /dev/null +++ b/tilelang/original/examples/flash_attention/example_mha_fwd_varlen.py @@ -0,0 +1,288 @@ +# ruff: noqa +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +import argparse + +import torch +from einops import rearrange, repeat +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + dim = q.shape[-1] + scale = (1.0 / dim) ** 0.5 # log2(e) + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + scores = torch.einsum("bthd,bshd->bhts", q, k) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) + scores = scores * scale + attention = torch.softmax(scores, dim=-1).to(v.dtype) + + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + output = torch.einsum("bhts,bshd->bthd", attention, v) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [UQ, heads, dim] + k_shape = [UKV, heads, dim] + v_shape = [UKV, heads, dim] + o_shape = [UQ, heads, dim] + + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") + K_shared = T.alloc_shared([block_N, dim], dtype, "shared") + V_shared = T.alloc_shared([block_N, dim], dtype, "shared") + O_shared = T.alloc_shared([block_M, dim], dtype, "shared") + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + v_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + v_current_seqlen = v_end_idx - v_start_idx + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d] + else: + Q_shared[i, d] = 0 + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = 
T.ceildiv(k_current_seqlen, block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + # Q * K + for i, d in T.Parallel(block_N, dim): + if k * block_N + i < k_current_seqlen: + K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d] + else: + K_shared[i, d] = 0 + if is_causal: + # Mask future keys (j > i) as well as out-of-bounds query/key positions with -inf + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0 + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set scores_max to 0 if it is -inf. + # This step is called Check_inf in the FlashAttention3 code, and it only needs to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)). This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + for i, d in T.grid(block_N, dim): + if k * block_N + i < v_current_seqlen: + V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d] + else: + V_shared[i, d] = 0 + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + causal = False + if causal: + total_flops *= 0.5 + + dtype = torch.float16 + device = torch.device("cuda") + window_size = (-1, -1) + + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v,
+ output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] # unpadded query length + UK = k_unpad.shape[0] # unpadded key length + UKV = k_unpad.shape[0] # unpadded query key length + + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + out = output_pad_fn(out_unpad) + + out_ref, _ = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + causal=causal, + ) + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + + import flash_attn + + fla_out_unpad = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + ) + fla_out = output_pad_fn(fla_out_unpad) + torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2) + + print("All checks passed.✅") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim) diff --git a/tilelang/original/examples/flash_attention/test_example_flash_attention.py b/tilelang/original/examples/flash_attention/test_example_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..da172bb62a4dee0ada293fc1249812905ded8de1 --- /dev/null +++ b/tilelang/original/examples/flash_attention/test_example_flash_attention.py @@ -0,0 +1,101 @@ +import tilelang.testing + +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_gqa_fwd_bshd +import example_mha_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_mha_bwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_gqa_bwd_tma_reduce_varlen + + +@tilelang.testing.requires_cuda +def test_example_gqa_bwd_tma_reduce_varlen(): + example_gqa_bwd_tma_reduce_varlen.main() + + +@tilelang.testing.requires_cuda +def test_example_gqa_bwd(): + example_gqa_bwd.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_bwd_wgmma_pipelined(): + example_gqa_bwd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_bwd(): + example_mha_bwd_bshd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) + + +@tilelang.testing.requires_cuda +def test_example_mha_bwd_bhsd(): + example_mha_bwd_bhsd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_bwd_wgmma_pipelined(): + example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_fwd_bshd_wgmma_pipelined(): + example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + + 
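+# NOTE: tests marked with requires_cuda_compute_version_ge(9, 0) exercise WGMMA-pipelined kernels and therefore need a Hopper (sm_90 or newer) GPU; the remaining tests only need a CUDA-capable device. +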
+@tilelang.testing.requires_cuda +def test_example_gqa_fwd_bshd(): + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_fwd_bhsd_wgmma_pipelined(): + example_mha_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_bhsd(): + example_mha_fwd_bhsd.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_fwd_bshd_wgmma_pipelined(): + example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256) + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_bshd(): + example_mha_fwd_bshd.main(batch=1, seq_len=256) + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_varlen(): + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/flash_attention/varlen_utils.py b/tilelang/original/examples/flash_attention/varlen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43e21cc3b80ce72eaa582407024ec2c42015731e --- /dev/null +++ b/tilelang/original/examples/flash_attention/varlen_utils.py @@ -0,0 +1,108 @@ +# ruff: noqa +import torch +from einops import rearrange, repeat +from bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == 
key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) diff --git a/tilelang/original/examples/flash_decoding/README.md b/tilelang/original/examples/flash_decoding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a1b416125dd31d705d50327940dd62ba3ee2a2a4 --- /dev/null +++ b/tilelang/original/examples/flash_decoding/README.md @@ -0,0 +1 @@ +# Flash Decoding diff --git a/tilelang/original/examples/flash_decoding/example_gqa_decode.py b/tilelang/original/examples/flash_decoding/example_gqa_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..ee42df2080f3f3ed468b522e9df0db25377144de --- /dev/null +++ b/tilelang/original/examples/flash_decoding/example_gqa_decode.py @@ -0,0 +1,495 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import itertools +from functools import lru_cache +from typing import Tuple, Dict + +torch.random.manual_seed(0) + + +def get_configs(): + block_N = [64, 128] + block_H = [64] + num_split = [1, 2, 4, 8] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +@lru_cache(maxsize=1) +def get_heuristic_config() -> Tuple[Dict, int]: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version == 89: + cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, 
threads=128) + else: + cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128) + return cfg, sm_version + + +# TODO(lei): fix warp specialized and tma lower pass +def get_pass_configs(): + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [batch, seqlen_kv, groups, dim] + shape_v = [batch, seqlen_kv, groups, dim] + shape_o = [batch, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // groups + + part_shape = [batch, heads, num_split, dim] + valid_block_H = min(block_H, kv_group_num) + valid_block_N = min(block_N, seqlen_kv // num_split) + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) + 
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) + T.copy( + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy( + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + 
for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 128], dtype) + lse_logsum_local = T.alloc_fragment([128], accum_dtype) + lse_max_local = T.alloc_fragment([128], accum_dtype) + scale_local = T.alloc_fragment([128], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), + # lse_local: (local_id, thread_id) + lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k, j in T.Parallel(num_split, 128): + lse_local[k, j] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.serial(num_split): + for j in T.Parallel(128): + lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j]) + for j in T.Parallel(128): + lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + for j in T.Parallel(128): + scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j]) + # Note: Pay attention to dim and the number of threads in Parallel + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[i] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def flashattn_gqa_decode_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, mask, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn(Q, K, V, mask, Output) + + if num_split > 1: + return flashattn_gqa_decode_split + else: + return flashattn_gqa_decode_no_split + + +def ref_program(query, key, value, mask, glse, Output_partial): + # """ + # Inputs: + # - query (Tensor): [batch, heads, dim] + # - key (Tensor): [batch, seqlen_kv, groups, dim] + # - value (Tensor): [batch, seqlen_kv, groups, dim] + # - mask (Tensor): [batch, seqlen_kv, groups] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = query.shape[-1] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # 
[batch_size, num_head_groups, groups, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + if mask is not None: + mask = rearrange(mask, "b s h -> b h s") + mask = mask.unsqueeze(1) + scores = scores.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def flash_split_ref(Q, K, V, mask): + num_split = 16 + batch = Q.size(0) + nheads = Q.size(1) + groups = K.size(2) + dim = Q.size(-1) + block_N = 32 + seqlen_kv = K.size(1) + num_head_groups = nheads // groups + + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] + if mask is not None: + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") + mask_local = mask_local.unsqueeze(1) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") + acc_o_out /= logsum_out[:, :, None] + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") + gacc_o[ks, :, :, :] = acc_o_out + glogsum[ks, :, :] = logsum_out + + return 
glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, K, V, mask, glse, Output_partial): + num_split = 16 + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) # [batch, heads] + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + scale = torch.exp2(lse - lse_logsum) # [batch, heads] + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) + + +def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None): + glse_, Output_partial_ = flash_split_ref(Q, K, V, mask) + return reduce_ref(Q, K, V, mask, glse_, Output_partial_) + + +def print_red_warning(msg): + print(f"\033[91m{msg}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): + sim = calc_sim(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff}") + if assert_: + raise AssertionError(f"{name} Error: {diff}") + else: + if print_: + print(f"passed: {name} diff={diff}") + + +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + qk_flops = 2 * batch * heads * kv_seqlen * dim + pv_flops = 2 * batch * heads * kv_seqlen * dim + total_flops = qk_flops + pv_flops + + if not tune: + config, sm_version = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + + q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16) + k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) + v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) + mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8) + split = config["num_split"] + glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16) + Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16) + o = kernel(q, k, v, mask, glse, Output_partial) + o_ref = ref_program(q, k, v, mask, glse, Output_partial) + o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial) + + print(o) + print(o_ref) + + assert_similar(o, o_ref, name="o_ref") + assert_similar(o, o_ref_split, name="o_ref_split") + + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, groups, kv_seqlen, dim) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: 
{best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3d8baed6fce23f474597d29aeb365e611e629e --- /dev/null +++ b/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -0,0 +1,909 @@ +import torch +import triton +import triton.language as tl +import math +import argparse +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +torch.manual_seed(0) +tilelang.disable_cache() + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@triton.jit +def _fwd_inner( + q, + k_ptrs, + v_ptrs, + s_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N: tl.constexpr, +): + """Inner loop computation for attention""" + + for blk_idx in tl.range(lo, hi): + start_n = blk_idx * BLOCK_N + k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) + v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) + + qk = tl.dot(q, k) + qk *= softmax_scale + qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) + + row_max = tl.max(qk, 1) + tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) + + m_ij = tl.maximum(m_i, row_max) + qk -= m_ij[:, None] + p = tl.math.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + acc *= alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + + return m_i, l_i, acc + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], + key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], +) +@triton.jit +def _fwd_kernel_varlen( + Q, # [token_q = b, h_q, dim] + K, # [token_k, h_kv, dim] + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_sb, + stride_sh, + stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] + gqa_group_size: tl.constexpr, + BLOCK_H: 
tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(0) + off_h_for_kv = tl.program_id(1) + off_h_q = off_h_for_kv * gqa_group_size + + cu_k_start = tl.load(cu_seqlens_k + off_z) + cu_k_end = tl.load(cu_seqlens_k + off_z + 1) + + seqlen_k = cu_k_end - cu_k_start + + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh + K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh + V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh + O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh + S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh + + mask_h = offs_h < gqa_group_size + q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) + + if s_aux is not None: + sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) + l_i = tl.zeros([BLOCK_H], dtype=tl.float32) + m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink + else: + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) + m_i, l_i, acc = _fwd_inner( + q, + k_ptrs, + v_ptrs, + S_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen_k, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N, + ) + + if s_aux is not None: + sink = tl.math.exp(sink - m_i) + l_i = l_i + sink + acc = acc / l_i[:, None] + + else: + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + for blk_idx in tl.range(lo, hi): + s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) + s = tl.exp(s - m_i) / l_i + tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) + + acc = acc.to(O.dtype.element_ty) + + tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor([batch, heads, dim], dtype), + S: 
T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + T.annotate_layout( + { + # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + # K_shared: tilelang.layout.make_swizzled_layout(K_shared), + # V_shared: tilelang.layout.make_swizzled_layout(V_shared), + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + # S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], + # -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in 
T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + # T.copy(S_shared, S_fragment) + # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + # S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + # T.copy(S_fragment, S_shared) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def flash_attn_with_attn_pool_decode( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + BLOCK_D = head_size + BLOCK_N = block_size + BLOCK_H = 64 + + O = torch.zeros_like(Q) + S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) + + def grid(META): + return (batch, k_h) + + with torch.cuda.device(Q.device.index): + _fwd_kernel_varlen[grid]( + Q, + K, + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + *Q.stride(), + *K.stride(), + *V.stride(), + *O.stride(), + *S.stride(), + gqa_group_size, + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + 
BLOCK_D=BLOCK_D, + ) + + if use_per_kv_head_sparse_index: + S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O, S + + +def test_equal_seqlen_decode_main(args): + """Test decode kernel with equal sequence lengths""" + print("Testing decode kernel with equal sequence lengths") + + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + # For decode, query is just 1 token per batch + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + softmax_scale = 1.0 / math.sqrt(head_size) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Convert to varlen format for K, V + k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) + v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) + + # Generate cumulative sequence lengths + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) + max_seqlen_k = k_seqlen + + print(f"q shape: {q.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Compute torch reference + q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] + k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + + if sink is None: + # Standard scaled dot-product attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + attn_weights = torch.softmax(logits, dim=-1) + O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) 
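+ # Attention-sink (s_aux) reference: the per-head sink acts as one extra logit that enters the softmax + # denominator but carries no value vector, so it only absorbs probability mass: + # attn[j] = exp(logit[j] - m) / (sum_k exp(logit[k] - m) + exp(sink - m)), where m = max(max_k logit[k], sink).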
+ unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] + + # Compute attention score pooling + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, k_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(torch.float16) + + print("S_tilelang", S_tilelang) + print("attn_score_pooled", attn_score_pooled) + + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) + max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) + + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + print("✅ All tests passed!") + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = 
k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + 
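# Same sink-aware normalization as the equal-length path, but positions past each
+ # sequence's true length were already set to -inf above, so exp(-inf - m) == 0 and they
+ # drop out of the normalizer; only valid keys and the s_aux term compete for mass. +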
unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) + + print("✅ All tests passed!") + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. 
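+ Runs `warmup` untimed calls of `fn`, then times `rep` calls with CUDA events
+ and returns the mean elapsed time in milliseconds.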
+ """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Benchmark + print("⚡ Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("⚡ Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + + print(f"Speedup: 
{(triton_time / tilelang_time):.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = False + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + elif args.test_varlen: + test_varlen_decode_main(args) + else: + test_equal_seqlen_decode_main(args) diff --git a/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..0984e707531fc49e3f7c2130b1299b9826c4ea53 --- /dev/null +++ b/tilelang/original/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -0,0 +1,679 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T +from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench + +torch.manual_seed(0) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +# @autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") +def flashattn( + batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, + block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + assert page_block_size >= block_N and page_block_size % block_N == 0, ( + "page_block_size must be larger than block_N and a multiple of block_N" + ) + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + 
V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / 
block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, + block_table: torch.Tensor = None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def test_equal_seqlen_decode_main(args): + """Test decode kernel with equal sequence lengths""" + print("Testing decode kernel with equal sequence lengths") + + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + # For decode, query is just 1 token per batch + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + softmax_scale = 1.0 / math.sqrt(head_size) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Convert to varlen format for K, V + k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + + # Generate cumulative sequence lengths + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, 
device="cuda", dtype=torch.int32) + max_seqlen_k = k_seqlen + + print(f"q shape: {q.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Compute torch reference + q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] + k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + + if sink is None: + # Standard scaled dot-product attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + attn_weights = torch.softmax(logits, dim=-1) + O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] + + # Compute attention score pooling + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, k_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(torch.float16) + + print("S_tilelang", S_tilelang) + print("attn_score_pooled", attn_score_pooled) + + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) + max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) + + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") + assert torch.allclose(O_triton, 
O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + print("✅ All tests passed!") + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - 
cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + 
stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) + + print("✅ All tests passed!") + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq 
len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Benchmark + print("⚡ Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + block_table, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("⚡ Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + print(f"Speedup: {(triton_time / tilelang_time):.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + elif args.test_varlen: + test_varlen_decode_main(args) + else: + test_equal_seqlen_decode_main(args) diff --git a/tilelang/original/examples/flash_decoding/example_mha_inference.py b/tilelang/original/examples/flash_decoding/example_mha_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5b243d695eefd8bb2d0fa0f18e30986d96ca7135 --- /dev/null +++ b/tilelang/original/examples/flash_decoding/example_mha_inference.py @@ -0,0 +1,322 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from 
functools import partial + +num_split = 4 + + +@tilelang.jit(out_idx=[5]) +def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, seqlen_q, heads, dim] + shape_kv = [batch, seqlen_kv, heads, dim] + part_shape = [batch, seqlen_q, heads, num_split, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape_kv, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + mid: T.int32, + hid: T.int32, + bid: T.int32, + sid: T.int32, + ): + T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared) + # TODO: Handle causal split case + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape_kv, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + hid: T.int32, + bid: T.int32, + sid: T.int32, + ): + T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
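+ # Note that `scale` already folds log2(e) into the 1/sqrt(dim) softmax factor
+ # (scale = (1/dim)**0.5 * 1.44269504), so exp2(x * scale) is exactly exp(x / sqrt(dim)).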
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + ): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + mid = bx + hid = by % heads + bid = by // heads + sid = bz + + # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently + # disable relevant tma copy and use SIMT as fallback for now + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # TODO: Handle causal split case + loop_range = ( + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=2): + MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_q, dtype), + ): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): + po_local = T.alloc_fragment([block_M, dim], dtype) + po_shared = T.alloc_shared([block_M, dim], dtype) + o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype) + o_shared = T.alloc_shared([block_M, dim], dtype) + lse_local = T.alloc_fragment([num_split, block_M], dtype) + lse_local_split = T.alloc_fragment([block_M], accum_dtype) + lse_logsum_local = T.alloc_fragment([block_M], accum_dtype) + lse_max_local = T.alloc_fragment([block_M], accum_dtype) + scale_local = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout( + { + o_accum_local: 
T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), + o_shared: tilelang.layout.make_swizzled_layout(o_shared), + po_shared: tilelang.layout.make_swizzled_layout(po_shared), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) + T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) + for k in T.Pipelined(num_split): + T.copy(lse_local[k, :], lse_local_split) + for i in T.Parallel(block_M): + lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i]) + for i in T.Parallel(block_M): + lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] + for k in T.Pipelined(num_split, num_stages=2): + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True) + T.copy(po_shared, po_local) + for i in T.Parallel(block_M): + lse_local_split[i] = lse_local[k, i] + for i in T.Parallel(block_M): + scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i]) + for i, j in T.Parallel(block_M, dim): + o_accum_local[i, j] += po_local[i, j] * scale_local[i] + T.copy(o_accum_local, o_shared) + T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) + + @T.prim_func + def flashattn_mha_inference( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), + ): + flash_attn_split(Q, K, V, glse, Output_partial) + combine(glse, Output_partial, Output) + + return flashattn_mha_inference + + +def ref_program(Q, K, V, glse, Output_partial, causal): + assert causal is False + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def reduce_ref(Q, K, V, glse, Output_partial, causal): + o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, seqlen_q, heads] + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks, :] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks, :] + scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q] + o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2) + return o.to(torch.float16) + + +def flash_split_ref(Q, K, V, causal): + # [batch, seqlen_q, heads, dim] + batch = Q.size(0) + block_M = Q.size(1) + nheads = Q.size(2) + dim = Q.size(3) + block_N = 128 + seqlen_kv = K.size(1) + + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads, 
block_M), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float) + + Q_ = Q * scale + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] + scores_scale = torch.exp2(scores_max_prev - scores_max) + acc_o *= scores_scale[:, :, :, None].transpose(1, 2) + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + acc_s_cast = acc_s.to(torch.float16) + acc_o += torch.einsum( + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, :, None].transpose(1, 2) + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :, :] = acc_o + glogsum[ks, :, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + + +def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + BLOCK_M = 128 + BLOCK_N = 64 # if D_HEAD <= 128 else 32 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + ref_fn = partial(ref_program, causal=causal) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01) + print("All checks passed!") + + latency = profiler.do_bench(ref_fn, warmup=500) + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(n_warmup=10, n_repeat=10) + print("{:.4f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/flash_decoding/test_example_flash_decoding.py b/tilelang/original/examples/flash_decoding/test_example_flash_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..c728dfe0e1f14712ba29363ca380fa425bd9d536 --- /dev/null +++ b/tilelang/original/examples/flash_decoding/test_example_flash_decoding.py @@ -0,0 +1,19 @@ +import tilelang.testing + +import example_gqa_decode +import example_mha_inference + + +# TODO(lei): fix the correctness of gqa decode on sm90 +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_example_gqa_decode(): + example_gqa_decode.main() + + +def test_example_example_mha_inference(): + example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/fusedmoe/example_fusedmoe_tilelang.py b/tilelang/original/examples/fusedmoe/example_fusedmoe_tilelang.py new file mode 
100644 index 0000000000000000000000000000000000000000..36c6ef3dc20004e9ac0076d2f6bc7680e5c33371 --- /dev/null +++ b/tilelang/original/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -0,0 +1,524 @@ +import math +import torch +import torch.nn as nn +from typing import Dict, Tuple, Optional +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from example_fusedmoe_torch import * + + +@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): + scale = 1.44269504 # log2(e) + + # Parameters + dhidden = d_hidden + dexpert = d_expert * n_shared_experts + + # Tensors: Note that input shape is reshape to (num_tokens, dhidden) + input_shape = (num_tokens, dhidden) + shared_W_gate_shape = (dexpert, dhidden) + shared_W_up_shape = (dexpert, dhidden) + shared_W_down_shape = (dhidden, dexpert) + + accum_type = T.float32 + + @T.prim_func + def kernel_shared( + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore + ): + # Step 1: Compute gate and up logits + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): + # Split the block to shared experts and routed experts + input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) + W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + W_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + # Shared experts: no need to check expert_indices + + gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type) + up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type) + + T.use_swizzle(10) + T.clear(gate_logits_local) + T.clear(up_logits_local) + + # Parallel for gate and up matmul + for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): + T.copy(input[bx * block_token, k * block_dhidden], input_shared) + T.copy(shared_W_gate[by * block_dexpert, k * block_dhidden], W_gate_shared) + T.copy(shared_W_up[by * block_dexpert, k * block_dhidden], W_up_shared) + T.gemm(input_shared, W_gate_shared, gate_logits_local, transpose_B=True) + T.gemm(input_shared, W_up_shared, up_logits_local, transpose_B=True) + + # Fuse with SiLU and element-wise product + for i, j in T.Parallel(block_token, block_dexpert): + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] + + T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) + + # Step 2: Compute down logits + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): + up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) + W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) + output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) + + T.use_swizzle(10) + T.clear(output_local) + + for k in 
T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): + T.copy(up_logits[bx * block_token, k * block_dexpert], up_logits_shared) + T.copy(shared_W_down[by * block_dhidden, k * block_dexpert], W_down_shared) + T.gemm(up_logits_shared, W_down_shared, output_local, transpose_B=True) + + T.copy(output_local, output[bx * block_token, by * block_dhidden]) + + return kernel_shared + + +@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=None, +): + scale = 1.44269504 # log2(e) + + # Parameters + dhidden = d_hidden + dexpert = d_expert + n_routed_experts = n_routed_experts + + # Group info + # group_sum = sum(group_sizes_list) + # group_count = len(group_sizes_list) + # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) + M = math.ceil(group_sum / block_token) + group_count + accum_dtype = T.float32 + + # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm + input_shape = (group_sum, dhidden) + intermediate_shape = (group_sum, dexpert) + routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) + routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) + routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts + + @T.prim_func + def kernel( + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore + ): + # Step 1: Compute gate and up logits + with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): + input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) + routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + + gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) + up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) + + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = T.alloc_local([1], T.int32) + + T.use_swizzle(10, enable=True) + + m_start_padded = bx * block_token + + cur_group_idx[0] = group_idx_for_bx[bx] + + cur_group_size[0] = group_sizes[cur_group_idx[0]] + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + + T.clear(gate_logits_local) + T.clear(up_logits_local) + + for k in 
T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): + T.copy( + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], + input_shared, + coalesced_width=coalesced_width, + ) + T.copy( + routed_expert_gate[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], + routed_expert_gate_shared, + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) + T.copy( + routed_expert_up[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], + routed_expert_up_shared, + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) + + for i, j in T.Parallel(block_token, block_dexpert): + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] + + for i, j in T.Parallel(block_token, block_dexpert): + if i < actual_rows: + up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j] + + # Step 2: Compute down logits + with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): + up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) + routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) + output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) + + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = T.alloc_local([1], T.int32) + + T.use_swizzle(10, enable=True) + + m_start_padded = bx * block_token + + cur_group_idx[0] = group_idx_for_bx[bx] + + cur_group_size[0] = group_sizes[cur_group_idx[0]] + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + + T.clear(output_local) + + for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): + T.copy( + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], + up_logits_shared, + coalesced_width=coalesced_width, + ) + T.copy( + routed_expert_down[ + cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], + routed_expert_down_shared, + coalesced_width=coalesced_width, + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True) + + for i, j in T.Parallel(block_token, block_dhidden): + if i < actual_rows: + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] + + return kernel + + +class Expert(nn.Module): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): + super().__init__() + self.config = config + self.act_fn = nn.SiLU() + self.d_hidden: int = config["d_hidden"] + self.d_expert: int = config["d_expert"] if d_expert is None else d_expert + self.device = torch.device("cuda") + + self.W_gate_weight = gate.t().contiguous().to(self.device) + self.W_up_weight = up.t().contiguous().to(self.device) + self.W_down_weight = down.t().contiguous().to(self.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = self.act_fn(x @ 
self.W_gate_weight) + out = (gate * (x @ self.W_up_weight)) @ self.W_down_weight + return out + + +class MoEGate(nn.Module): + def __init__(self, config: Dict, weights: Dict): + super().__init__() + self.top_k: int = config["n_experts_per_token"] + self.num_experts: int = config["n_routed_experts"] + self.d_hidden: int = config["d_hidden"] + + self.W_g_weight = weights["router.weight"].t() + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + logits = x @ self.W_g_weight + scores = logits.softmax(dim=-1) + topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + return topk_indices, topk_scores + + +class MoE(nn.Module): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): + super().__init__() + self.config = config + self.shared_kernel = shared_kernel + self.routed_kernel = routed_kernel + self.padding_M = padding_M + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) + self.device = torch.device("cuda") + self.gating_network = MoEGate(config, weights).to(self.device) + shared_expert_dim = config["d_expert"] * config["n_shared_experts"] + self.shared_expert = Expert( + config=config, + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) + self.expert_cache = torch.zeros( + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0) + self.stacked_expert_tokens = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), + dtype=torch.float16, + device=self.device, + ) + self.stacked_expert_weights = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_tokens_idxs = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) + + self.up_logits_shared = torch.empty( + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device + ) + self.expert_output_shared = torch.empty( + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.up_logits_routed = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), + dtype=torch.float16, + device=self.device, + ) + self.expert_output_routed = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), + dtype=torch.float16, + device=self.device, + ) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_shape = x.shape + batch_size, seq_len, hidden_dim = x.shape + expert_indices, expert_scores = self.gating_network(x) + flat_expert_indices = expert_indices.view(-1) + 
flat_expert_weights = expert_scores.view(-1) + x_flat = x.view(-1, hidden_dim) + + # Prepare for grouped GEMM + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + # counts = flat_expert_indices.bincount() + tokens_per_expert = counts.cumsum() + # tokens_per_expert = torch.cumsum(counts, dim=0) + num_per_tok = self.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + + # Get stacked expert tokens and expert weights + + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x_flat[exp_token_idxs] + + self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens + self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] + + group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) + + group_padded_offsets = [0 for _ in range(len(group_sizes))] + for i in range(1, len(group_sizes)): + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M + + block_token = 128 + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + self.config["n_routed_experts"] + ) + group_idx_for_bx = [0 for _ in range(M)] + + for bx in range(M): + m_start_padded = bx * block_token + for i in range(self.config["n_routed_experts"]): + if m_start_padded >= group_padded_offsets[i]: + group_idx_for_bx[bx] = i + + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) + group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) + + # Multi-stream execution + shared_stream = torch.cuda.Stream() + routed_stream = torch.cuda.default_stream() + torch.cuda.synchronize() + + with torch.cuda.stream(routed_stream): + # Tilelang version: Grouped GEMM + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + self.up_logits_routed, + self.expert_output_routed, + ) + + # Scatter reduce + self.expert_cache = torch.scatter_reduce( + self.expert_cache, + 0, + self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), + self.expert_output_routed, + reduce="sum", + ) + routed_output = self.expert_cache.view(*orig_shape) + + with torch.cuda.stream(shared_stream): + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) + shared_output = self.expert_output_shared.view(*orig_shape) + + torch.cuda.synchronize() + + return shared_output + routed_output + + +def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: + """ + DeepSeek-style Mixture of Experts using Tilelang. 
+ + Args: + data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict) + - input: Input tensor of shape [batch_size, seq_len, hidden_size] + - weights: Dictionary containing model weights + - config: Dictionary containing model configuration parameters + + Returns: + Tuple containing: + - output: Processed tensor [batch_size, seq_len, d_model] + """ + input_tensor, weights, config = data + + dtype_str = T.float16 + + shared_kernel = moe_forward_tilelang_shared( + config["d_hidden"], + config["d_expert"], + config["n_shared_experts"], + dtype=dtype_str, + num_tokens=config["batch_size"] * config["seq_len"], + ) + routed_kernel = moe_forward_tilelang_routed( + config["d_hidden"], + config["d_expert"], + config["n_routed_experts"], + dtype=dtype_str, + group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], + group_count=config["n_routed_experts"], + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=2, + ) + + moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) + + output = moe(input_tensor) + + return output + + +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394, + } + + data = generate_input(**config) + + torch.cuda.synchronize() + ref_output = ref_kernel(clone_data(data)).to(torch.float32) + torch.cuda.synchronize() + tilelang_output = custom_kernel(clone_data(data)).to(torch.float32) + torch.cuda.synchronize() + + torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2) + print("✅ Tilelang and Torch match") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/fusedmoe/example_fusedmoe_torch.py b/tilelang/original/examples/fusedmoe/example_fusedmoe_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6322aff7dce196ce12f371d83f47e5c1fa82e4 --- /dev/null +++ b/tilelang/original/examples/fusedmoe/example_fusedmoe_torch.py @@ -0,0 +1,210 @@ +import math +import torch +import torch.nn as nn +from typing import Dict, Tuple, Optional + + +# Reference code in PyTorch +class ExpertTorch(nn.Module): + def __init__(self, config: Dict, d_expert: Optional[int] = None): + super().__init__() + self.config = config + self.act_fn = nn.SiLU() + self.d_hidden: int = config["d_hidden"] + self.d_expert: int = config["d_expert"] if d_expert is None else d_expert + + self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False) + self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False) + self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = self.act_fn(self.W_gate(x)) + out = self.W_down(gate * self.W_up(x)) + return out + + +class MoEGateTorch(nn.Module): + def __init__(self, config: Dict): + super().__init__() + self.top_k: int = config["n_experts_per_token"] + self.num_experts: int = config["n_routed_experts"] + self.d_hidden: int = config["d_hidden"] + + self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + logits = self.W_g(x) + scores = logits.softmax(dim=-1) + topk_scores, topk_indices = 
torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + return topk_indices, topk_scores + + +class MoETorch(nn.Module): + def __init__(self, config: Dict): + super().__init__() + self.config = config + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.gating_network = MoEGateTorch(config) + shared_expert_dim = config["d_expert"] * config["n_shared_experts"] + self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shared_output = self.shared_expert(x) + expert_indices, expert_scores = self.gating_network(x) + batch_size, seq_len, hidden_dim = x.shape + orig_shape = x.shape + x_flat = x.view(-1, hidden_dim) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1, 1) + routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights) + + routed_output = routed_output_flat.view(*orig_shape) + return routed_output + shared_output + + @torch.no_grad() + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: + expert_cache = torch.zeros_like(x) + # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) + # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) + # test_expert_ups = torch.zeros((self.config["n_routed_experts"], self.config["d_hidden"], self.config["d_expert"])) + # test_expert_tokens_num = torch.zeros((self.config["n_routed_experts"])) + + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + num_per_tok = self.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + expert = self.experts[expert_id] + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idxs] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") + + return expert_cache + + +def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: + """ + Reference implementation of DeepSeek-style Mixture of Experts using PyTorch. 
+ + Args: + data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict) + - input: Input tensor of shape [batch_size, seq_len, hidden_dim] + - weights: Dictionary containing model weights + - config: Dictionary containing model configuration parameters + + Returns: + Tuple containing: + - output: Processed tensor [batch_size, seq_len, d_model] + """ + input_tensor, weights, config = data + num_experts = config["n_routed_experts"] + moe = MoETorch(config) + + # Fill in the given weights of the model + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) + + for i in range(num_experts): + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] + + # Transpose weights to match expected shape for nn.Linear + moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) + moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) + moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) + + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) + + output = moe(input_tensor) + + return output + + +# Input generation for the reference code + + +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: + # Really dumb but for now _ isn't parsing correctly. + d_hidden = dhidden + d_expert = dexpert + n_routed_experts = nroutedexperts + n_shared_experts = nsharedexperts + n_experts_per_token = nexpertspertoken + batch_size = bs + seq_len = seqlen + + config = { + "d_hidden": d_hidden, + "d_expert": d_expert, + "n_routed_experts": n_routed_experts, + "n_shared_experts": n_shared_experts, + "n_experts_per_token": n_experts_per_token, + "batch_size": batch_size, + "seq_len": seq_len, + } + + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + + num_experts = n_routed_experts + expert_dim = d_expert + weights = {} + + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() + + # Initialize router weights + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) + + for i in range(num_experts): + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", 
dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + return (input_tensor, weights, config) + + +def clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(clone_data(x) for x in data) + elif isinstance(data, list): + return [clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data diff --git a/tilelang/original/examples/fusedmoe/test_example_fusedmoe.py b/tilelang/original/examples/fusedmoe/test_example_fusedmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8415895d52f75dc8bf029b7a97e1cabc983b03 --- /dev/null +++ b/tilelang/original/examples/fusedmoe/test_example_fusedmoe.py @@ -0,0 +1,12 @@ +import tilelang.testing +import example_fusedmoe_tilelang + + +def test_example_fusedmoe_tilelang(): + example_fusedmoe_tilelang.main( + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gdn/README.md b/tilelang/original/examples/gdn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..31dd2361e125595c469fef4b44c5e1128d6e96e4 --- /dev/null +++ b/tilelang/original/examples/gdn/README.md @@ -0,0 +1,15 @@ +# Gated Delta Net (GDN) kernel implementation with TileLang + +## Requirement + +- TileLang: `0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1` +- Triton: `3.3.0` (used for comparison) +- FLA: commit `f03cb3ae` (used for comparison) + +## Get started + + The [chunk_delta_h](common/chunk_delta_h.py) implements the most critical forward kernel of GDN. It's a good start to understand the GDN logic and the TileLang optimization. + +## Acknowledgments + +This kernel was developed by Yu Cheng and Zhengju Tang following in-depth discussions with Xiaomi's LLM-Core Team (MiMo). 
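
The GDN README above points at the chunked forward kernel as the natural starting point. As a quick orientation before the full example files that follow, here is a minimal driver sketch of that forward path, mirroring the `prepare_input` / `run_test` flow in `example_chunk_delta_h.py`. It is only a sketch under stated assumptions: it assumes a CUDA device, that it is run from the `examples/gdn` directory so the example module imports, and that the optional `fla` checkout is importable (otherwise the example's input prep skips the chunk-wise gate cumsum); the shapes and dtypes below are illustrative, not prescribed by the README.

```python
# Minimal sketch of driving the TileLang chunk_delta_h forward kernel.
# Assumptions: CUDA GPU available, run from examples/gdn, optional `fla` on sys.path.
import torch
import tilelang.language as T

from example_chunk_delta_h import (
    prepare_input,
    tilelang_chunk_gated_delta_rule_fwd_h,
)

# Illustrative problem size; S must be a multiple of chunk_size.
B, S, H, DK, DV, chunk_size = 1, 4096, 8, 128, 128, 64

# K/W/U are normalized random tensors, G holds chunk-wise cumsum'd log-gates,
# initial_state is the (B, H, DK, DV) starting hidden state.
K, W, U, G, initial_state = prepare_input(
    B, S, H, DK, DV, chunk_size,
    torch.bfloat16, torch.bfloat16, torch.float32, torch.float32,
)

# Build (and autotune) the JIT kernel; argument order follows run_test():
# task shape, dtypes, chunk size, then the use_g / use_initial_state /
# store_final_state / save_new_value flags.
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
    B, S, H, DK, DV,
    T.bfloat16, T.bfloat16, T.float32, T.float32, T.float32,
    chunk_size,
    True, True, True, True,
)

# Outputs: per-chunk hidden states h, the final state, and the rewritten values V_new.
h, final_state, V_new = kernel(K, W, U, G, initial_state)
print(h.shape, final_state.shape, V_new.shape)
```

The first call pays the autotuning cost (the config sweep over `block_DK`, `block_DV`, threads, and pipeline stages shown in `get_configs()`); after that the compiled kernel is reused directly, which is how the benchmarking below measures it.
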
diff --git a/tilelang/original/examples/gdn/example_chunk_delta_bwd.py b/tilelang/original/examples/gdn/example_chunk_delta_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..39450bc5fc5917e1958687111ba684aabecf800d --- /dev/null +++ b/tilelang/original/examples/gdn/example_chunk_delta_bwd.py @@ -0,0 +1,613 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 + +import tilelang +import tilelang.language as T + +print(tilelang.__file__, flush=True) + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__, flush=True) + from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(0) +# torch.set_printoptions(profile="full") + +from test_utils import assert_similar + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + # Note: G should be in logspace and do chunkwise cumsum + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + try: + from fla.ops.utils.cumsum import chunk_local_cumsum + + G = chunk_local_cumsum(G, chunk_size) + except ImportError: + print("fla not found, skip cumsum") + + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dh, dh0, dv2 + + +def torch_chunk_gated_delta_rule_bwd_dhu( + Q: torch.Tensor, + K: torch.Tensor, + W: torch.Tensor, + G: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + dO: torch.Tensor, + dv: torch.Tensor, + scale: float, + use_g: bool, + use_initial_state: bool, + use_final_state_gradient: bool, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + B, S, H, DK = Q.shape + DV = dv.shape[-1] + block_S = 64 + BS = S // block_S + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), 
dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) + dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) + dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) + Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) + + if use_final_state_gradient: + dh_tmp = dht.clone().to(accum_dtype) + else: + dh_tmp = torch.zeros_like(dht).to(accum_dtype) + + for i_s in range(BS - 1, -1, -1): + dh[:, i_s, :, :, :] = dh_tmp + dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + if use_g: + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + for i_s2 in range(block_S): + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) + else: + dv_tmp[i_b, i_s2, i_h, :] = 0 + dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp + + if use_g: + G_last = G[:, i_s * block_S + block_S - 1, :] + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] + for i_s2 in range(block_S): + for i_k in range(DK): + Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) + Q_tmp *= scale + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] + + torch.backends.cuda.matmul.allow_tf32 = True + dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) + dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3)) + torch.backends.cuda.matmul.allow_tf32 = False + + if use_initial_state: + dh0 = dh_tmp[:, :, :, :] + else: + dh0 = torch.zeros_like(dh_tmp[:, :, :, :]) + print(dh0.dtype) + + return dh, dh0, dv2 + + +@tilelang.jit(out_idx=[-3, -2, -1]) +def tilelang_chunk_gated_delta_rule_bwd_dhu( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + # kernel config + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + # Should support cu_seqlen + BS = S // block_S + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + W_shape = (B, S, H, DK) + G_shape = (B, S, H) + h0_shape = (B, H, DK, DV) + dht_shape = (B, H, DK, DV) + dO_shape = (B, S, H, DV) + dv_shape = (B, S, H, DV) + + dh_shape = (B, BS, H, DK, DV) + dh0_shape = (B, H, DK, DV) + dv2_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) + b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype) + b_dh_fragment = T.alloc_fragment((DK, block_DV), 
dtype=accum_dtype) + b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32) + dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32) + dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_last_local_exp = T.alloc_local((1), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared") + G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype) + G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype) + G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype) + Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype) + + T.use_swizzle(10) + + T.annotate_layout( + { + b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), + b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + } + ) + + if use_final_state_gradient: + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) + T.copy(b_dh_shared, b_dh_fragment) + else: + T.clear(b_dh_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # The gradient should be stored in the reverse order + i_s_inv = T.ceildiv(S, block_S) - i_s - 1 + + # Store the updated dh + T.copy(b_dh_fragment, b_dh_shared) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Update dv + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) + + if use_g: + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) + T.copy(G_shared, G_fragment) + G_last_local[0] = G_shared[block_S - 1] + G_last_local_exp[0] = T.exp(G_last_local[0]) + for i_s2 in T.Parallel(block_S): + G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2]) + for i_s2, i_v in T.Parallel(block_S, block_DV): + # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): + with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): + with T.Then(): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] + with T.Else(): + dv_fragment[i_s2, i_v] = 0 + + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) + T.copy(dv_shared, dv_fragment_2) + for i_s2, i_v in T.Parallel(block_S, 
block_DV): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] + + # Store the updated dv + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + # Update dh + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + + T.clear(Q_fragment) + if use_g: + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] *= G_last_local_exp[0] + T.copy(Q_shared, Q_fragment) + for i_s2 in T.Parallel(block_S): + G_fragment_exp[i_s2] = T.exp(G_shared[i_s2]) + for i_s2, i_k in T.Parallel(block_S, DK): + # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale + Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale + else: + T.copy(Q_shared, Q_fragment) + for i_s2, i_k in T.Parallel(block_S, DK): + Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale + # Get transpose of Q_fragment to meet tf32 gemm requirement + for i_s2, i_k in T.Parallel(block_S, DK): + Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] + + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) + T.copy(dO_shared, dO_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] + T.copy(dO_fragment_t, dO_shared_t) + + T.clear(b_dh_fragment_1) + T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True) + T.clear(b_dh_fragment_2) + T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True) + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] + + if use_initial_state: + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name): + try: + torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dh_0 and dh_1 passed for {name}") + except Exception as e: + print(f"{name} dh_0 and dh_1 are not close for {name}") + print(e, end="\n\n") + try: + torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dh0_0 and dh0_1 passed for {name}") + except Exception as e: + print(f"{name} dh0_0 and dh0_1 are not close for {name}") + print(e, end="\n\n") + try: + torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dv2_0 and dv2_1 passed for {name}") + except Exception as e: + print(f"{name} dv2_0 and dv2_1 are not close for {name}") + print(e, end="\n\n") + + close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}" + ) + error_num += 1 + close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), 
indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}" + ) + error_num += 1 + close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}" + ) + error_num += 1 + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=64, + threads=256, + num_stages=0, + use_torch=False, +): + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + + # fla ref + print("fla running...", flush=True) + if use_g: + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) + else: + G = G.fill_(0) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) + + # tilelang + print("tilelang running...", flush=True) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) + # kernel = tilelang.compile(program) + print(kernel.get_kernel_source()) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) + + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) + + print(f"fla time: {fla_time} ms") + print(f"tilelang time: {tilelang_time} ms") + + assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh") + assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0") + assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2") + + # torch ref + if use_torch: + print("torch running...", flush=True) + if use_g: + dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( + Q, + K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref_torch = dh_ref_torch.cuda() + dh0_ref_torch = dh0_ref_torch.cuda() + dv2_ref_torch = dv2_ref_torch.cuda() + else: + dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + 
use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref_torch = dh_ref_torch.cuda() + dh0_ref_torch = dh0_ref_torch.cuda() + dv2_ref_torch = dv2_ref_torch.cuda() + + assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh") + assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0") + assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2") + assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh") + assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0") + assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2") + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def main(): + DK = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=128, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + scale=DK**-0.5, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=32, + threads=128, + num_stages=1, + use_torch=False, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_chunk_delta_h.py b/tilelang/original/examples/gdn/example_chunk_delta_h.py new file mode 100644 index 0000000000000000000000000000000000000000..d316a62116c8b2fda7ca24daddc838e0cda94144 --- /dev/null +++ b/tilelang/original/examples/gdn/example_chunk_delta_h.py @@ -0,0 +1,408 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F +from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 + +from test_utils import assert_similar + +# (zhengju) We can slightly modify the generated cuda code from tilelang lowering +# in the debug folder to make the performance better. To enable this callback, +# you can comment out the following function. 
+# @register_cuda_postproc_callback +# def tilelang_callback_cuda_postproc(code, _): +# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read() +# code = cuda_code +# return code + +torch.random.manual_seed(0) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + W = F.normalize(W, dim=-1, p=2) + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + U = F.normalize(U, dim=-1, p=2) + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + try: + from fla.ops.utils.cumsum import chunk_local_cumsum + + G = chunk_local_cumsum(G, chunk_size) + except ImportError: + print("fla not found, skip cumsum") + + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + return K, W, U, G, initial_state + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + state_dtype, +): + BS = S // chunk_size + h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return h, final_state, V_new + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_gated_delta_rule_fwd_h( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + # kernel config + block_DK=64, + block_DV=32, + threads=128, + num_stages=1, +): + block_S = chunk_size + BS = S // block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + U_shape = (B, S, H, DV) + G_shape = (B, S, H) + h_shape = (B, BS, H, DK, DV) + initial_state_shape = (B, H, DK, DV) + final_state_shape = (B, H, DK, DV) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype) + b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + + U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + G_last_local = T.alloc_local((1), 
dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) + G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) + + T.annotate_layout( + { + b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) + + T.use_swizzle(10) + + if use_initial_state: + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) + T.copy(b_h_shared, b_h_fragment) + else: + T.clear(b_h_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # Store previous result to the hidden tensor, like the epilogue + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Recurrence + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) + T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) + + # U - W * S + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) + T.copy(U_shared, U_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] + + # Save V_new + if save_new_value: + T.copy(V_new_fragment, dst=V_new_shared) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) + # use_g + if use_g: + G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] + for i_s2, i_v in T.Parallel(block_S, block_DV): + G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh] + T.copy(G_shared, G_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): + with T.Then(): + V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( + (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695 + ) + with T.Else(): + V_new_fragment[i_s2, i_v] = 0 + G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) + for i_k, i_v in T.Parallel(DK, block_DV): + b_h_fragment[i_k, i_v] *= G_last_local[0] + + # Update intermediate results + T.copy(V_new_fragment, V_new_shared) + T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True) + + T.copy(b_h_fragment, b_h_shared) + + # Save final state + if store_final_state: + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. 
+ """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + + # fla ref + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) + + # tilelang + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) + # (zhengju) If you want to print the generated cuda code, you can uncomment the following line + # print("CUDA Code:\n", kernel.get_kernel_source()) + + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) + tilelang_time = do_bench(kernel, K, W, U, G, initial_state) + + # check correctness + try: + h_ref_fp32 = h_ref.to(torch.float32) + h_tilelang_fp32 = h_tilelang.to(torch.float32) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) + print("tilelang chunk gated delta rule fwd h passed √") + except Exception as e: + print("tilelang chunk gated delta rule fwd h failed ✗") + print(e) + + try: + final_state_ref_fp32 = final_state_ref.to(torch.float32) + final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32) + assert_similar( + final_state_ref_fp32, + final_state_tilelang_fp32, + eps=1e-5, + name="tilelang chunk gated delta rule fwd final_state", + raise_assert=False, + ) + print("tilelang chunk gated delta rule fwd final_state passed √") + except Exception as e: + print("tilelang chunk gated delta rule fwd final_state failed ✗") + print(e) + + try: + V_new_ref_fp32 = V_new_ref.to(torch.float32) + V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) + print("tilelang chunk gated delta rule fwd V_new passed √") + except Exception as e: + 
print("tilelang chunk gated delta rule fwd V_new failed ✗") + print(e) + + print(f"tilelang time: {tilelang_time} ms") + print(f"fla time: {fla_time} ms") + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + use_g=True, + use_initial_state=False, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_chunk_o.py b/tilelang/original/examples/gdn/example_chunk_o.py new file mode 100644 index 0000000000000000000000000000000000000000..81536815923b294c38a02540f436670fc71798a4 --- /dev/null +++ b/tilelang/original/examples/gdn/example_chunk_o.py @@ -0,0 +1,246 @@ +# Reference: fla/ops/common/chunk_o.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_o import chunk_fwd_o +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + BS = chunk_size + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + return Q, K, V, HIDDEN, G + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, +): + O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return O + + +@tilelang.jit(out_idx=[-1]) +def tilelang_chunk_fwd_o( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + # kernel config + block_S=64, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + H_shape = (B, S // BS, H, DK, DV) + G_shape = (B, S, H) + O_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): + bb, bh = bbh // H, bbh % H + Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + O_fragment = T.alloc_fragment((block_S, 
block_DV), dtype=accum_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") + G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) + + T.annotate_layout( + { + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + H_shared: tilelang.layout.make_swizzled_layout(H_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.clear(A_fragment) + T.clear(O_fragment) + T.disable_warp_group_reg_alloc() + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) + T.gemm(Q_shared, H_shared, O_fragment) + T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) + + if use_g: + for i_s in T.Parallel(block_S): + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + # T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + for i_s, i_v in T.Parallel(block_S, block_DV): + O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s]) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G_diff_local[i_s1, i_s2] <= 0): + with T.Then(): + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) + with T.Else(): + A_fragment[i_s1, i_s2] = 0 + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 < i_s2): # noqa: SIM117 + with T.Then(): + A_fragment[i_s1, i_s2] = 0 + + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) + T.copy(A_fragment, A_shared) + T.gemm(A_shared, V_shared, O_fragment) + + for i_s, i_v in T.Parallel(block_S, block_DV): + O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale + + T.copy(O_fragment, O_shared) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + use_g, + block_DK, + block_DV, + threads, + num_stages, +): + input_dtype_torch = getattr(torch, input_dtype) + output_dtype_torch = getattr(torch, output_dtype) + accum_dtype_torch = getattr(torch, accum_dtype) + gate_dtype_torch = getattr(torch, gate_dtype) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) + scale = 1.0 / DK**0.5 + + O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) + O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size) + + block_S = chunk_size + O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) + O_tilelang = kernel(Q, K, V, HIDDEN, G) + + try: + torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2) + print("tilelang chunk fwd o passed √") + except 
Exception as e: + print("tilelang chunk fwd o failed ✗") + print(e) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + use_g=True, + block_DK=128, + block_DV=128, + threads=128, + num_stages=1, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_chunk_o_bwd.py b/tilelang/original/examples/gdn/example_chunk_o_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..97e2f4f01e08f29fcb845df5af728c7bf00b2728 --- /dev/null +++ b/tilelang/original/examples/gdn/example_chunk_o_bwd.py @@ -0,0 +1,526 @@ +# Reference: fla/ops/common/chunk_o.py + +import math +import sys # noqa: F401 + +import tilelang +import tilelang.language as T +from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_o import chunk_bwd_dqkwg +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +from test_utils import assert_similar + +torch.random.manual_seed(0) +# torch.set_printoptions(profile="full") + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + return Q, K, V, h, G, dO, dh, dv, W + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda() + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + return Q, K, V, h, G, dO, dh, dv, W + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, + block_DK, +): + assert DK == 32 and block_DK == 32 or DK > 32 and block_DK >= 64, "When DK > 32, block_DK must be >= 64" + NK = math.ceil(DK / block_DK) + dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dw = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dg = torch.empty(NK, B, S, H, dtype=gate_dtype).cuda() + return dq, dk, dw, dg + + +# @register_cuda_postproc_callback +# def tilelang_callback_cuda_postproc(code, _): +# cuda_code = open("../debug/chunk_o_bwd3.log", 
"r").read() +# code = cuda_code +# return code + + +@tilelang.jit( + out_idx=[-4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def tilelang_chunk_o_bwd_dqkwg( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_dw=True, + # kernel config + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + BS = S // block_S + NK = math.ceil(DK / block_DK) + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + h_shape = (B, BS, H, DK, DV) + G_shape = (B, S, H) + dO_shape = (B, S, H, DV) + dh_shape = (B, BS, H, DK, DV) + dv_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + + dq_shape = (B, S, H, DK) + dk_shape = (B, S, H, DK) + dw_shape = (B, S, H, DK) + dg_shape = (NK, B, S, H) + + @T.prim_func + def kernel( + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): + bb, bh = bbh // H, bbh % H + + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + k_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + ds_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + dg_shared_1 = T.alloc_shared((block_S,), dtype=gate_dtype) + dg_shared_2 = T.alloc_shared((block_S,), dtype=gate_dtype) + dk_shared = T.alloc_shared((block_S, block_DK), dtype=accum_dtype) + + ds_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + ds_fragment_positive = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + ds_fragment_positive_transpose = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_2 = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + q_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + k_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + + dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype) + dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_last_local = T.alloc_local((2,), dtype=gate_dtype) + dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype) + 
dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype) + dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype) + dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared") + G_last_local = T.alloc_local((1,), dtype=gate_dtype) + + T.use_swizzle(10) + + T.annotate_layout( + { + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + h_shared: tilelang.layout.make_swizzled_layout(h_shared), + dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) + + T.clear(dg_last_local) + T.clear(G_last_local) + T.clear(G_shared) + T.clear(q_fragment) + T.clear(k_fragment) + T.clear(dg_last_fragment) + + T.clear(ds_fragment) + T.clear(dq_fragment) + T.clear(dk_fragment) + T.clear(dw_fragment) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) + + if use_g: + T.clear(dg_last_fragment_scalar) + # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will cause incorrect result + # for i_kv in T.Parallel(block_DK * block_DV): + # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] + for i_kv in T.Parallel(block_DK * block_DV): + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] + T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) + dg_last_local[0] += dg_last_fragment_scalar[0] + + T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True) + T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True) + T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) + + if use_dw: + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) + T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) + + if use_dw: + for i_s, i_k in T.Parallel(block_S, block_DK): + dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared) + T.copy(q_shared, q_fragment) + T.copy(k_shared, k_fragment) + + if use_g: + T.clear(dg_fragment) + T.clear(dg_fragment_2) + for i_s, i_k in T.Parallel(block_S, block_DK): + G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh] + G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh] + # Use gmem directly instead of local register + dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) + + for i_s, i_k in T.Parallel(block_S, block_DK): + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * 
scale + T.clear(dg_fragment_reduce_tmp) + for i_s, i_k in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) + + for i_s, i_k in T.Parallel(block_S, block_DK): + with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): + with T.Then(): + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh]) + with T.Else(): + dk_fragment[i_s, i_k] = 0 + T.clear(dg_fragment_reduce_tmp) + for i_s, i_k in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k]) + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) + + # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will cause incorrect result + T.copy(dk_fragment, dk_shared) + T.clear(dg_last_fragment_scalar_2) + for i_sk in T.Parallel(block_S * block_DK): + i_s, i_k = i_sk // block_DK, i_sk % block_DK + dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k] + T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False) + dg_last_local[1] = dg_last_fragment_scalar_2[0] + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.Then(): + ds_fragment[i_s1, i_s2] = ( + ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale + ) + with T.Else(): + ds_fragment[i_s1, i_s2] = 0 + + T.clear(ds_fragment_positive) + T.clear(ds_fragment_positive_transpose) + T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) + T.copy(dg_fragment, dg_shared_1) + + # We should transpose the matrix because the reduce_sum statement can only reduce along the last dimension + for i_s1, i_s2 in T.Parallel(block_S, block_S): + ds_fragment_positive_transpose[i_s2, i_s1] = ds_fragment_positive[i_s1, i_s2] + + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(ds_fragment_positive_transpose, dg_fragment_2, dim=1, clear=False) + T.copy(dg_fragment_2, dg_shared_2) + + for i_s in T.Parallel(block_S): + dg_fragment_final[i_s] = dg_shared_1[i_s] - dg_shared_2[i_s] + + T.copy(ds_fragment, ds_shared) + T.gemm(ds_shared, k_shared, dq_fragment) + T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True) + + for i_s in T.Parallel(block_S): + with T.If(i_s >= block_S - 1): # noqa: SIM117 + with T.Then(): + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + for i_s in T.Parallel(block_S): + dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] + + else: + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with 
T.If(i_s1 < i_s2): # noqa: SIM117 + with T.Then(): + ds_fragment[i_s1, i_s2] = 0 + T.clear(dk_fragment_2) + T.copy(ds_fragment, ds_shared) + T.gemm(ds_shared, k_shared, dq_fragment) + T.gemm(ds_shared, q_shared, dk_fragment_2, transpose_A=True) + for i_s, i_k in T.Parallel(block_S, block_DK): + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + return kernel + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_dw=True, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) + + # ref + if use_g: + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + else: + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + + # tilelang + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) + + if use_g: + dg_tilelang = dg_tilelang.sum(dim=0) + + # check + try: + assert_similar(dq_ref, dq_tilelang, 1e-5, "tilelang chunk o bwd dq") + print("tilelang chunk o bwd dq passed √") + except Exception as e: + print("tilelang chunk o bwd dq failed ✗") + print(e) + + try: + assert_similar(dk_ref, dk_tilelang, 1e-5, "tilelang chunk o bwd dk") + print("tilelang chunk o bwd dk passed √") + except Exception as e: + print("tilelang chunk o bwd dk failed ✗") + print(e) + + if use_g: + try: + assert_similar(dg_ref, dg_tilelang, 1e-5, "tilelang chunk o bwd dg") + print("tilelang chunk o bwd dg passed √") + except Exception as e: + print("tilelang chunk o bwd dg failed ✗") + print(e) + + if use_dw: + try: + assert_similar(dw_ref, dw_tilelang, 1e-5, "tilelang chunk o bwd dw") + print("tilelang 
chunk o bwd dw passed √") + except Exception as e: + print("tilelang chunk o bwd dw failed ✗") + print(e) + + +def main(): + DK = 128 + DV = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=DV, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + scale=DK**-0.5, + # scale=1, + use_g=True, + use_dw=True, + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_chunk_scaled_dot_kkt.py b/tilelang/original/examples/gdn/example_chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ef17e3f4b50479958692460d4fc785b0e82a18 --- /dev/null +++ b/tilelang/original/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -0,0 +1,197 @@ +# Reference: fla/ops/common/chunk_scaled_dot_kkt.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.set_printoptions(profile="full") +torch.random.manual_seed(0) + + +def prepare_input( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=accum_dtype).cuda() + return K, Beta, G + + +def prepare_output( + B, + S, + H, + chunk_size, + dtype, +): + BS = chunk_size + A = torch.empty(B, S, H, BS, dtype=dtype).cuda() + return A + + +@tilelang.jit(out_idx=[-1]) +def tilelang_chunk_scaled_dot_kkt_fwd( + # task config + B, + S, + H, + DK, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + use_g=True, + # kernel config + block_S=64, + block_DK=64, + threads=256, + num_stages=0, +): + K_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + output_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + # !! 
Pay attention to the scope of the shared memory: may cause misaligned address when shape is one dimension or the buffer is too small + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared") + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + + # Tensor used for gated: + G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") + G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) + + T.fill(A_fragment, 0) + T.disable_warp_group_reg_alloc() + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] + T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) + + if use_g: + for i_s in T.Parallel(block_S): + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): + with T.Then(): + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) + with T.Else(): + A_fragment[i_s1, i_s2] = 0 + else: + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): # noqa: SIM117 + with T.Then(): + A_fragment[i_s1, i_s2] = 0 + + T.copy(A_fragment, A_shared) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + use_g, + block_DK, + threads, + num_stages, +): + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) + A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) + A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) + + # reference + if use_g: + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + else: + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + + # tilelang + block_S = chunk_size + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) + A_tilelang = kernel(K, Beta, G) + + try: + torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2) + print("tilelang chunk scaled dot kkt fwd passed √") + except Exception as e: + print("tilelang chunk scaled dot kkt fwd failed ✗") + print(e) + print("reference cuda kernel:") + print(kernel.get_kernel_source()) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + use_g=True, + block_DK=64, + threads=128, + num_stages=2, + ) + + +if __name__ == 
"__main__": + main() diff --git a/tilelang/original/examples/gdn/example_cumsum.py b/tilelang/original/examples/gdn/example_cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0760b496458ae334d1baf54934c0526962a529a6 --- /dev/null +++ b/tilelang/original/examples/gdn/example_cumsum.py @@ -0,0 +1,165 @@ +# Util functions for flash linear attention cumsum +# Reference: fla/ops/utils/cumsum.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.utils.cumsum import chunk_local_cumsum_scalar +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + + +@tilelang.jit( + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) +def tilelang_chunk_local_cumsum_scalar( + # task config + B, + S, + H, + chunk_size=64, + is_varlen=False, + head_first=False, + reverse=False, + input_dtype=T.float16, + output_dtype=T.float32, + # kernel config + block_S=64, + threads=256, + use_fragment=False, +): + G_shape = (B, H, S) if head_first else (B, S, H) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == block_S, "chunk_size must be equal to block_S" + + @T.prim_func + def kernel( + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") + if head_first: + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) + else: + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) + if use_fragment: + G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") + T.copy(G_shared, G_fragment) + T.cumsum(G_fragment, dim=1, reverse=reverse) + if head_first: + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) + else: + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) + else: + T.cumsum(G_shared, dim=1, reverse=reverse) + if head_first: + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) + else: + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) + + return kernel + + +def prepare_cumsum_input( + B, + S, + H, + dtype, +): + G = torch.randn(B, S, H, dtype=dtype).cuda() + return G + + +def prepare_cumsum_output( + B, + S, + H, + dtype, +): + G_new = torch.empty(B, S, H, dtype=dtype).cuda() + return G_new + + +def run_test( + B, + S, + H, + chunk_size, + reverse, + head_first, + input_dtype, + output_dtype, + threads, + use_fragment, +): + G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype)) + G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype)) + G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype)) + + # reference cumsum + G_new_ref = chunk_local_cumsum_scalar( + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) + + # tilelang cumsum + block_S = chunk_size + kernel = tilelang_chunk_local_cumsum_scalar( + B=B, + S=S, + H=H, + chunk_size=chunk_size, + reverse=reverse, + 
head_first=head_first, + input_dtype=input_dtype, + output_dtype=output_dtype, + block_S=block_S, + threads=threads, + use_fragment=use_fragment, + ) + torch.cuda.profiler.start() + G_new_tilelang = kernel(G) + torch.cuda.profiler.stop() + try: + torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2) + print("tilelang cumsum passed √") + except Exception as e: + print("tilelang cumsum failed ✗") + print(e) + print("G:") + print(G.view(-1)) + print("G_new_tilelang:") + print(G_new_tilelang.view(-1)) + print("G_new_ref:") + print(G_new_ref.view(-1)) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + chunk_size=64, + reverse=True, + head_first=False, + input_dtype=T.float32, + output_dtype=T.float32, + threads=256, + use_fragment=False, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_wy_fast.py b/tilelang/original/examples/gdn/example_wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac086ca76916afddbb671377b97809546528e40 --- /dev/null +++ b/tilelang/original/examples/gdn/example_wy_fast.py @@ -0,0 +1,220 @@ +# Reference: fla/ops/gated_delta_rule/wy_fast.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.random.manual_seed(1) + + +def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32): + BS = chunk_size + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda() + return K, V, Beta, G, A + + +def prepare_output( + B, + S, + H, + DK, + DV, + output_dtype, +): + W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return W, U + + +@tilelang.jit(out_idx=[-2, -1]) +def tilelang_recompute_w_u_fwd( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + # kernel config + block_S=64, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared") + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + G_shared = T.alloc_shared((block_S,), 
dtype=gate_dtype, scope="shared") + A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype) + U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), + U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), + } + ) + + T.disable_warp_group_reg_alloc() + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] + T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) + # First copy to smem, then copy to gmem to reduce U2RU instructions + T.copy(U_fragment, U_shared) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) + # First copy to smem, then copy to gmem to reduce U2RU instructions + T.copy(W_fragment, W_shared) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + block_DK, + block_DV, + threads, + num_stages, +): + K, V, Beta, G, A = prepare_input( + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) + W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) + W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) + + # reference + W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None) + + # tilelang + block_S = chunk_size + kernel = tilelang_recompute_w_u_fwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + block_S=block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, + ) + print(kernel.get_kernel_source()) + W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) + + try: + torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2) + print("tilelang recompute w passed √") + except Exception as e: + print("tilelang recompute 
w failed ✗") + print(e) + try: + torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2) + print("tilelang recompute u passed √") + except Exception as e: + print("tilelang recompute u failed ✗") + print(e) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + gate_dtype=T.float32, + accum_dtype=T.float32, + block_DK=64, + block_DV=32, + threads=128, + num_stages=3, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/example_wy_fast_bwd_split.py b/tilelang/original/examples/gdn/example_wy_fast_bwd_split.py new file mode 100644 index 0000000000000000000000000000000000000000..de8afc2b7770432db2535c05c69a852b1af802da --- /dev/null +++ b/tilelang/original/examples/gdn/example_wy_fast_bwd_split.py @@ -0,0 +1,535 @@ +# Reference: fla/ops/gated_delta_rule/wy_fast.py + +import sys # noqa: F401 + +import tilelang +import tilelang.language as T + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id 00000000 +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(0) +torch.set_printoptions(profile="full") + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = chunk_size + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + Beta = torch.ones(B, S, H, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda() + dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + return K, V, Beta, G, A, dw, du + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = chunk_size + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + V = F.normalize(V, dim=-1, p=2) + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda() + dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return K, V, Beta, G, A, dw, du + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda() + dg = torch.empty(B, S, H, dtype=gate_dtype).cuda() + return dk, dv, dbeta, dg + + +@tilelang.jit( + out_idx=[-5, -4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def tilelang_wy_fast_bwd( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, 
+): + block_S = chunk_size + BS = block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + dw_shape = (B, S, H, DK) + du_shape = (B, S, H, DV) + + dk_shape = (B, S, H, DK) + dv_shape = (B, S, H, DV) + dbeta_shape = (B, S, H) + dg_shape = (B, S, H) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared_beta_g = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + V_shared_beta = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype) + G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype) + dw_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + du_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_beta_g = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_beta = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype) + dbeta_fragment_v = T.alloc_fragment((block_S,), dtype=accum_dtype) + dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dbeta_fragment_reduce_tmpv = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype) + + T.use_swizzle(10) + + T.clear(dA_fragment) + T.clear(dk_fragment) + T.clear(dk_fragment_beta_g) + T.clear(dv_fragment) + T.clear(dv_fragment_beta) + T.clear(dbeta_fragment_k) + T.clear(dbeta_fragment_v) + T.clear(dg_fragment) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + G_shared_exp[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) + + # Update dk + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared) + T.gemm(dw_shared, K_shared_beta_g, 
dA_fragment, transpose_B=True) + T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) + + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) + + # correct dk + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + # Update dv + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) + T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) + T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + dv_fragment[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * Beta_shared[i_s] + # for i_s, i_v2 in T.Parallel(block_S, block_DV): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] + for i_s, i_v2 in T.Parallel(block_S, block_DV): + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] + T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) + + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + # Temporary store dbeta, dg and dA + for i_s in T.Parallel(block_S): + dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] + dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] + # correct dA + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) + + return kernel + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) +def tilelang_wy_fast_bwd_split( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, +): + block_S = chunk_size + BS = block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + dw_shape = (B, S, H, DK) + du_shape = (B, S, H, DV) + + dk_shape = (B, S, H, DK) + dv_shape = (B, S, H, DV) + dbeta_shape = (B, S, H) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: 
T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dA_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dA_A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dA_A_fragment_1 = T.alloc_fragment((block_S,), dtype=accum_dtype) + dA_A_fragment_2 = T.alloc_fragment((block_S,), dtype=accum_dtype) + dk_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dk_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_beta = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype) + dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype) + G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype) + + T.clear(dbeta_fragment_reduce_tmpk) + T.clear(dbeta_fragment_k) + T.clear(dA_A_fragment_1) + T.clear(dA_A_fragment_2) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + for i_s in T.Parallel(block_S): + G_shared_exp[i_s] = T.exp(G_shared[i_s]) + + # Load intermediate results + # for i_s in T.Parallel(block_S): + # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] + # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) + # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + + # Update dA + T.copy(dA_shared, dA_fragment) + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): # noqa: SIM117 + with T.Then(): + dA_fragment[i_s1, i_s2] = 0 + T.copy(dA_fragment, dA_shared) + T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True) + T.copy(dA_fragment, dA_shared) + T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): + with T.Then(): + dA_fragment[i_s1, i_s2] = 0 + with T.Else(): + dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2] + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.Then(): + dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) + with T.Else(): + 
dA_fragment[i_s1, i_s2] = 0 + T.copy(dA_fragment, dA_shared) + + # acceptable dA diff + # T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + + # Update dk using previous dk + T.clear(A_fragment) + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) + T.copy(dk_shared, dk_fragment) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] + T.gemm(K_shared_beta, K_shared, A_fragment, transpose_B=True) + T.gemm(dA_shared, K_shared, dk_fragment_beta, clear_accum=True) + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] + T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) + T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + # Update dg and dbeta + T.copy(A_fragment, A_shared) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dA_A_fragment[i_s1, i_s2] = dA_fragment[i_s1, i_s2] * A_fragment[i_s1, i_s2] + # Note: Reduce operation now not supported in shared memory + # FIXME: reduce will cause incorrect result when dim != -1 + T.reduce_sum(dA_A_fragment, dA_A_fragment_1, dim=1) + T.reduce_sum(dA_A_fragment, dA_A_fragment_2, dim=0) + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dg_A_positive[bb, bs * block_S + i_s1, bh, i_s2] = dA_A_fragment[i_s1, i_s2] + dg_A_negative[bb, bs * block_S + i_s2, bh, i_s1] = dA_A_fragment[i_s1, i_s2] + + for i_s in T.Parallel(block_S): + dbeta_k[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, +): + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + BS = chunk_size + dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() + dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() + dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + + # ref + dk_ref, dv_ref, dbeta_ref, 
dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) + + # tilelang + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) + torch.cuda.synchronize() + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) + torch.cuda.synchronize() + + dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + from test_utils import assert_similar + + assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) + assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) + assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) + assert_similar(dg_ref, dg_tilelang, eps=1e-5, name="dg", raise_assert=False) + + +def main(): + DK = 128 + DV = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=DV, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + block_DK=32, + block_DV=32, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gdn/test_example_gdn_compilation.py b/tilelang/original/examples/gdn/test_example_gdn_compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..e749fa0874ee5dbddecbabe23ce1c72f74662647 --- /dev/null +++ b/tilelang/original/examples/gdn/test_example_gdn_compilation.py @@ -0,0 +1,320 @@ +import torch +import tilelang.testing +from tilelang import language as T + +B = 1 +S = 1024 # small but for test only. 
+H = 32 +DK = 128 +DV = 128 +input_dtype = T.bfloat16 +output_dtype = T.bfloat16 +accum_dtype = T.float32 +gate_dtype = T.float32 +state_dtype = T.float32 +chunk_size = 64 +use_g = True +use_initial_state = True +store_final_state = True +use_final_state_gradient = True +save_new_value = True +block_DK = 64 +block_DV = 32 +threads = 128 +num_stages = 1 + + +def test_example_wy_fast_compilation(): + from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + + K, V, Beta, G, A = prepare_input( + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) + # tilelang + block_S = chunk_size + kernel = tilelang_recompute_w_u_fwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + block_S=block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, + ) + print(kernel.get_kernel_source()) + W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) + + +def test_example_wy_fast_bwd_split_compilation(): + from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + BS = chunk_size + dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() + dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() + dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + + # tilelang + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) + torch.cuda.synchronize() + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) + torch.cuda.synchronize() + + dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + +def test_example_chunk_o_compilation(): + from example_chunk_o import tilelang_chunk_fwd_o, prepare_input + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + scale = 1.0 / DK**0.5 + block_S = chunk_size + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) + O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 + + +def 
test_example_chunk_o_bwd_compilation(): + from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 + if use_g: + dg_tilelang = dg_tilelang.sum(dim=0) + + +def test_example_chunk_scaled_dot_kkt_compilation(): + from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input + + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) + block_S = chunk_size + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) + A_tilelang = kernel(K, Beta, G) # noqa: F841 + + +def test_example_cumsum_compilation(): + from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) + G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) + block_S = chunk_size + kernel = tilelang_chunk_local_cumsum_scalar( + B=B, + S=S, + H=H, + chunk_size=chunk_size, + reverse=False, + head_first=False, + input_dtype=gate_dtype, + output_dtype=gate_dtype, + block_S=block_S, + threads=threads, + use_fragment=False, + ) + G_new_tilelang = kernel(G) # noqa: F841 + + +def test_example_chunk_delta_h_compilation(): + from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 + + +def test_example_chunk_delta_bwd_compilation(): + from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gdn/test_utils.py b/tilelang/original/examples/gdn/test_utils.py new 
file mode 100644 index 0000000000000000000000000000000000000000..3588551ce39ad1bee4267b727979336d73341561 --- /dev/null +++ b/tilelang/original/examples/gdn/test_utils.py @@ -0,0 +1,38 @@ +import torch + + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): + x_mask = torch.isfinite(x) + y_mask = torch.isfinite(y) + if not torch.all(x_mask == y_mask): + print_red_warning(f"{name} Error: isfinite mask mismatch") + if raise_assert: + raise AssertionError + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") + if raise_assert: + raise AssertionError + x = x.masked_fill(~x_mask, 0) + y = y.masked_fill(~y_mask, 0) + sim = calc_sim(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff}") + if raise_assert: + raise AssertionError + else: + print(f"{name} {data} passed") diff --git a/tilelang/original/examples/gemm/README.md b/tilelang/original/examples/gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ab7fb6614654e4594484d2e72ba9a7703b2706b --- /dev/null +++ b/tilelang/original/examples/gemm/README.md @@ -0,0 +1,452 @@ +# TileLang GEMM (Matrix Multiplication) Examples + +TileLang is a domain-specific language designed to simplify the process of writing GPU kernels. It provides high-level abstractions for memory allocation, scheduling, and tiling, which are critical for achieving maximum performance on modern hardware architectures like NVIDIA GPUs. This README demonstrates how to write and optimize a matrix multiplication (GEMM) kernel using TileLang. 
+ +## Table of Contents + +- [Table of Contents](#table-of-contents) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) +- [Simple GEMM Example](#simple-gemm-example) + - [Code Walkthrough](#code-walkthrough) + - [Compiling and Profiling](#compiling-and-profiling) +- [Advanced GEMM Features](#advanced-gemm-features) + - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) + - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) + - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) +- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) +- [Verifying Correctness](#verifying-correctness) +- [Fine-grained MMA Computations](#fine-grained-mma-computations) + - [Example Workflow](#example-workflow) + - [Summary](#summary) +- [References](#references) + +--- + +## Getting Started + +### Prerequisites + +- **Python 3.8+** +- **NVIDIA GPU** with a recent CUDA toolkit installed +- **PyTorch** (optional, for easy correctness verification) +- **tilelang** +- **bitblas** (optional; used for swizzle layout utilities in the advanced examples) + +### Installation + +```bash +pip install tilelang bitblas +``` + +*(Adjust accordingly if you are installing from source or using a different environment.)* + +--- + +## Simple GEMM Example + +Below is a basic matrix multiplication (GEMM) example demonstrating how TileLang handles buffer allocation, tiling, and kernel dispatch. For simplicity, we'll multiply two 1024×1024 matrices using 128 threads/block. + +```python +import tilelang +from tilelang import Profiler +import tilelang.language as T + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover M×N + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + # Loop over the K dimension in block_K chunks, using a 3-stage pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main +``` + +### Code Walkthrough + +1. **Define the Kernel Launch Configuration:** + ```python + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + ``` + This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads. + +2. **Shared Memory Allocation:** + ```python + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + ``` + Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access. + +3. 
**Local Fragment Accumulation:** + ```python + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + ``` + Partial results are stored in registers (or local memory) to reduce writes to global memory. + +4. **Pipelined Loading and GEMM:** + ```python + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(...) + T.gemm(...) + ``` + Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation. + +5. **Copy Out the Results:** + ```python + T.copy(C_local, C[by * block_M, bx * block_N]) + ``` + Writes the final computed tile from registers/shared memory to global memory. + +### Compiling and Profiling + +```python +func = matmul(1024, 1024, 1024, 128, 128, 32) +print(func) # Prints an IR-like representation of the TileLang kernel + +artifact = tilelang.lower(func) + +profiler = Profiler(artifact.rt_mod, artifact.params, result_idx=[2]) + +import torch +a = torch.randn(1024, 1024).cuda().half() +b = torch.randn(1024, 1024).cuda().half() + +c = profiler(a, b) +ref_c = a @ b + +# Validate results +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +# Get CUDA Kernel Source +print(artifact.kernel_source) +``` + +--- + +## Advanced GEMM Features + +### Custom Memory Layout / Swizzling + +**Swizzling** rearranges data in shared memory or global memory to mitigate bank conflicts, improve cache utilization, and better match the GPU’s warp execution pattern. TileLang provides helper functions like `make_swizzle_layout` to annotate how buffers should be laid out in memory. + +### Parallel Copy and Auto-Pipelining + +- **Parallel Copy** allows you to distribute the copy of a block tile across all threads in a block, speeding up the transfer from global memory to shared memory. +- **Auto-Pipelining** uses multiple stages to overlap copying with computation, reducing idle cycles. + +### Rasterization for L2 Cache Locality + +Enabling **swizzle (rasterization)** at the kernel level can improve data reuse and reduce cache thrashing in L2. This is especially important when matrices are large. + +--- + +## Enhanced GEMM Example with Annotations + +Below is a more advanced snippet that showcases how to apply memory layouts, enable swizzling, and parallelize the copy operations to maximize performance: + +```python +import tilelang.language as T +# `make_mma_swizzle_layout` is a python-defined layout function +# that helps align data for MMA (Matrix Multiply-Accumulate) operations. 
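+# Applying it through `T.annotate_layout` (see below) permutes the shared-memory
+# addresses each thread touches, which is how the bank conflicts described in the
+# previous section are mitigated.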
+from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # Allocate shared and local fragments + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Annotate memory layout + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Enable swizzle-based rasterization for better L2 locality + T.use_swizzle(panel_size=10, enable=True) + + # Clear the local accumulation buffer + T.clear(C_local) + + # Pipelined iteration over K dimension + for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + T.copy(A[by * block_M, idx * block_K], A_shared) + + # Parallel copy tile of B + for ko, j in T.Parallel(block_K, block_N): + B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j] + + # Perform local GEMM on the shared-memory tiles + T.gemm(A_shared, B_shared, C_local) + + # Copy the result tile back + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main +``` + +**Key Differences vs. Basic Example** +1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). +2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. +3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. + +--- + +## Verifying Correctness + +Once you compile and load your kernel into a runtime module (`rt_mod`), you can use tools like **PyTorch** to easily create random matrices on the GPU, run your TileLang kernel, and compare the results to a reference implementation (e.g., `torch.matmul` or `@` operator). + +```python +import torch + +# Suppose your compiled kernel is in rt_mod +profiler = Profiler(rt_mod, params, result_idx=[2]) + +A = torch.randn(1024, 1024).cuda().half() +B = torch.randn(1024, 1024).cuda().half() + +C_tilelang = profiler(A, B) +C_ref = A @ B + +torch.testing.assert_close(C_tilelang, C_ref, rtol=1e-2, atol=1e-2) +print("Results match!") +``` + +--- + +## Fine-grained MMA Computations + +For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. 
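+
+Before walking through the full kernel, it helps to see where the per-thread fragment
+sizes in the example come from. The short sketch below only reproduces that arithmetic
+for the tile shape the example uses (16×16×16 MMA tiles, a 32-thread warp, 32×32 warp
+tiles); it is plain Python and illustrative only, not part of the kernel itself.
+
+```python
+warp_size = 32                                    # threads per warp
+micro_size_x = micro_size_y = micro_size_k = 16   # one MMA tile is 16x16x16
+warp_row_tiles = warp_col_tiles = 32              # warp-level tile of the output
+
+# Each 16x16 fragment is spread across the 32 threads of a warp,
+# so every thread privately holds 16*16/32 = 8 of its elements.
+local_size_a = (micro_size_x * micro_size_k) // warp_size   # 8
+local_size_b = (micro_size_y * micro_size_k) // warp_size   # 8
+local_size_c = (micro_size_x * micro_size_y) // warp_size   # 8
+
+# A warp covers a 32x32 output tile, i.e. 2x2 MMA tiles, which sets how many
+# fragments each thread keeps live in registers.
+warp_rows = warp_row_tiles // micro_size_x   # 2
+warp_cols = warp_col_tiles // micro_size_y   # 2
+
+print(warp_rows * local_size_a)                # 16 A elements per thread
+print(warp_cols * local_size_b)                # 16 B elements per thread
+print(warp_rows * warp_cols * local_size_c)    # 32 accumulator elements per thread
+```
+
+These are exactly the sizes passed to `T.alloc_local` in the workflow below.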
+ +### Example Workflow + +```python +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] +``` + +1. 
**Set Up Tile Sizes and Thread Bindings** + Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID). + +2. **Allocate Warp-local Fragments** + Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like: + ```python + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + ``` + Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles. + +3. **Load Data via `ldmatrix`** + Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well: + ```python + for ki in T.serial(0, (block_K // micro_size_k)): + # Warp-synchronous load for A + mma_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Warp-synchronous load for B + mma_emitter.ldmatrix_b(B_local, B_shared, ki) + ``` + Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers. + +4. **Perform the MMA Instruction** + After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially: + \[ + C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}} + \] + where each thread in the warp calculates a small portion of the final tile. For instance: + ```python + mma_emitter.mma(A_local, B_local, C_local) + ``` + Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel. + +5. **Store Results via `stmatrix`** + Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet: + ```python + mma_emitter.stmatrix(C_local, C_shared) + ``` + orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer. + +### Summary + +By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual thread bindings and memory allocations, you can replicate the control and performance of raw CUDA at the TileLang level. This approach is best suited for expert users who are comfortable with GPU warp-level programming, since it does require a deep understanding of hardware concurrency, memory hierarchies, and scheduling. However, the payoff can be significant for performance-critical paths, where every byte of bandwidth and every cycle of latency must be carefully orchestrated. + +--- + +## References + +- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. 
+- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. +- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul. diff --git a/tilelang/original/examples/gemm/example_gemm.py b/tilelang/original/examples/gemm/example_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..906a55d5d00fb516c00a10bacdf42c916a23bdb3 --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm.py @@ -0,0 +1,61 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + print("c:") + print(c) + print("ref_c:") + print(ref_c) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # Get CUDA Source + print("CUDA Source:") + print(kernel.get_kernel_source()) + + # benchmark + profiler = kernel.get_profiler() + latency = profiler.do_bench(backend="cupti") + # latency = profiler.do_bench() + print(f"tilelang Latency: {latency}ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm/example_gemm_autotune.py b/tilelang/original/examples/gemm/example_gemm_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..ca322217341f5745193372751ad00a009d9f42c2 --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm_autotune.py @@ -0,0 +1,239 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tilelang.autotuner import AutoTuner +from tilelang.carver.template import MatmulTemplate +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.carver.roller.rasterization import NoRasterization +import torch + + +def ref_program(A, B): + """ + Compute the matrix product of A and the transpose of B. + + A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes. + """ + return A @ B.T + + +def get_configs(M, N, K, with_roller=False, topk=20): + """ + Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. + + When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended + configurations (device-specific TensorCore-friendly tilings). 
Each returned dict contains: + - block_M, block_N, block_K: tile sizes + - num_stages: pipeline staging (0 means no explicit staging) + - thread_num: total threads used for the block + - enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling) + + When with_roller is False this returns the Cartesian product of a fixed set of candidate + parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag. + + Parameters: + M, N, K (int): GEMM dimensions used to generate valid tile sizes. + with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; + otherwise use a predefined candidate grid. + topk (int): Maximum number of roller hints to request when with_roller is True. + + Returns: + List[dict]: A list of configuration dictionaries as described above. + + Raises: + ValueError: if with_roller is True but the roller returns no hints. + """ + if with_roller: + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + roller_hints = carve_template.recommend_hints(topk=topk) + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0 + config["thread_num"] = block_rows * block_cols * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + else: + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_best_config(M, N, K, with_roller=False): + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, + ): + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + 
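+                    # Epilogue: stage the accumulated tile through shared memory
+                    # (C_shared) before the final copy back to global C.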
T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( + out_idx=[-1], + target="auto", + ) + .set_profile_args( + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + ) + ) + return autotuner.run(warmup=3, rep=20) + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tl.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_autotune( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune + + +def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): + use_autotune = True + if use_autotune: + result = get_best_config(M, N, K, with_roller) + print(result.config) + kernel = result.kernel + else: + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + + # benchmark + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + ref_latency = profiler.do_bench(ref_program) + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}") + print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, 
help="Whether to enable BitBLAS roller for search space") + args = parser.parse_args() + main(args.m, args.n, args.k, args.use_autotune, args.with_roller) diff --git a/tilelang/original/examples/gemm/example_gemm_intrinsics.py b/tilelang/original/examples/gemm/example_gemm_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..746e6ec011d8e44830431198dc03060ba4e5af91 --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm_intrinsics.py @@ -0,0 +1,185 @@ +from tilelang import tvm as tvm +from tvm import DataType +import tilelang +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + # chunk = 32 if in_dtype == T.float16 else 64 + chunk = 32 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_intrinsics( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * 
warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Load B into fragment + mma_emitter.ldmatrix_b(B_local, B_shared, ki) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix(C_local, C_shared) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_intrinsics + + +def ref_program(A, B): + return A @ B.T + + +def main(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32 + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + profiler = kernel.get_profiler() + + latency = profiler.do_bench(profiler.func, warmup=25) + + print(latency) + + # Ensure that the latency is not None + assert latency is not None + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main(M=4096, N=4096, K=4096) diff --git a/tilelang/original/examples/gemm/example_gemm_intrinsics_dcu.py b/tilelang/original/examples/gemm/example_gemm_intrinsics_dcu.py new file mode 100644 index 0000000000000000000000000000000000000000..e43bef16d7c3f64044a4a338c48313be5e25fb2e --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm_intrinsics_dcu.py @@ -0,0 +1,189 @@ +from tilelang import tvm as tvm +from tvm import DataType +import tilelang +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mmac_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang import disable_cache + +disable_cache() + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + # chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 + shared_scope = "shared.dyn" + + 
# Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMAC Wrapper to Auto Generate Code for MMAC + mmac_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_intrinsics( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mmac_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Load B into fragment + mmac_emitter.ldmatrix_b(B_local, B_shared, ki) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_local, B_local, C_local) + + # Perform STMatrix + mmac_emitter.stmatrix(C_local, C_shared) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + j // micro_size_y, + i // micro_size_x, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_intrinsics + + +def ref_program(A, B): + return A @ B.T + + +def main(): + M, N, K = 16384, 16384, 16384 + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + profiler = kernel.get_profiler() + + latency = profiler.do_bench(profiler.func, warmup=25) + + print(latency) + print(kernel.get_kernel_source()) + # Ensure that the latency is not None + assert latency is not None + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git 
a/tilelang/original/examples/gemm/example_gemm_persistent.py b/tilelang/original/examples/gemm/example_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..30f55de6a06d06eadabd9461ee0eba1169521764 --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm_persistent.py @@ -0,0 +1,136 @@ +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +import argparse + + +@tilelang.jit(out_idx=[-1]) +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(10) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return main + + +@tilelang.jit(out_idx=[-1]) +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True +): + sm_num = driver.get_num_sms() + m_blocks = T.ceildiv(M, block_M) + n_blocks = T.ceildiv(N, block_N) + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 8 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(sm_num, threads=threads) as (block_id): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + for w in T.serial(waves): + tile_id = sm_num * w + block_id + bx = (tile_id // group_size) % m_blocks + by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size + + if bx * block_M < M and by * block_N < N: + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + @T.prim_func + def main_persistent_primitive( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(sm_num, threads=threads) as (block_id): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return main_persistent_primitive if 
use_persistent_primitive else main + + +def ref_program(A, B): + return A @ B + + +def main(M=4096, N=4096, K=4096): + total_flops = 2 * M * N * K + + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + num_stages = 3 + + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("Persistent GEMM: All check passed.") + persistent_latency = persistent_profiler.do_bench(warmup=500) + print(f"Persistent GEMM Latency: {persistent_latency} ms") + print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") + + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("Non-Persistent GEMM: All check passed.") + non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) + print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms") + print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops") + + print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") + args = parser.parse_args() + M, N, K = args.M, args.N, args.K + main(M, N, K) diff --git a/tilelang/original/examples/gemm/example_gemm_schedule.py b/tilelang/original/examples/gemm/example_gemm_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..8663c878d043c86a765be08ccf99ae87ce9d1bee --- /dev/null +++ b/tilelang/original/examples/gemm/example_gemm_schedule.py @@ -0,0 +1,68 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_schedule( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 Cache Locality + T.use_swizzle(panel_size=10) + + # Clear the local buffer + T.clear(C_local) + + # Auto pipeline the computation + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Instead of using + # T.copy(B[k * block_K, bx * block_N], B_shared) + # we can also use Parallel to auto map the thread + # bindings and vectorize the copy operation. 
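+                # (With this kernel's loop index `ko`, the equivalent call is
+                #  T.copy(B[ko * block_K, bx * block_N], B_shared).)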
+ for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_schedule + + +def main(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + print("c:") + print(c) + print("ref_c:") + print(ref_c) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # Get CUDA Source + print("CUDA Source:") + print(kernel.get_kernel_source()) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm/test_example_gemm.py b/tilelang/original/examples/gemm/test_example_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f69364be64fd95f1aa59b3e8a5f69e1a4c57dcf --- /dev/null +++ b/tilelang/original/examples/gemm/test_example_gemm.py @@ -0,0 +1,26 @@ +import tilelang.testing +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule +import example_gemm + + +def test_example_gemm_autotune(): + # enable roller for fast tuning + example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True) + + +def test_example_gemm_intrinsics(): + example_gemm_intrinsics.main(M=1024, N=1024, K=1024) + + +def test_example_gemm_schedule(): + example_gemm_schedule.main() + + +def test_example_gemm(): + example_gemm.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gemm_fp8/README.md b/tilelang/original/examples/gemm_fp8/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9d7011a064fb97fb4c78e8f64267d1a4a7f1e35f --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/README.md @@ -0,0 +1 @@ +**Notes**: Now we only support fp8 with mma instructions instead of `T.gemm`, because the cutlass version of tilelang is too old, we should update the cutlass version in future. 
\ No newline at end of file diff --git a/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_amd.py b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_amd.py new file mode 100644 index 0000000000000000000000000000000000000000..93f8c4980c36409e38afb1439b244918eee31748 --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -0,0 +1,116 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import torch_assert_close +import itertools + + +def ref_program(A, B): + return (A.half() @ B.half().T).to(dtype=torch.float32) + + +def manual_check_prog(C, C_ref): + torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1) + + +def supply_prog(args): + a_param, b_param = args + M, K = a_param.shape + N, _ = b_param.shape + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + return [a, b] + + +def get_configs(): + block_Ms = [32, 64, 128] + block_Ns = [32, 64, 128] + block_Ks = [64, 128] + num_stages = [0] + num_threads = [256] + k_packs = [1, 2] + gemm_types = ["ss", "rs"] + + valid_configs = [] + + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) + return valid_configs + + +@tilelang.autotune( + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) +@tilelang.jit(out_idx=[-1]) +def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): + dtype = T.float8_e4m3fnuz + accum_dtype = T.float32 + + @T.prim_func + def gemm_fp8_rs( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_local) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + @T.prim_func + def gemm_fp8_ss( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + if gemm_type == "ss": + return gemm_fp8_ss + elif gemm_type == "rs": + return gemm_fp8_rs + else: + raise ValueError(f"Invalid gemm_type: 
{gemm_type}") + + +def test_gemm_fp8(M, N, K): + kernel = fp8_matmul(M, N, K) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + c = kernel(a, b) + ref_c = ref_program(a, b) + torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("passed~") + + +if __name__ == "__main__": + test_gemm_fp8(512, 512, 512) diff --git a/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..1b440a7952a211b61f20e8c7c849d474a89cb0c1 --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -0,0 +1,63 @@ +import torch +import tilelang +import tilelang.language as T + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): + @T.prim_func + def gemm_fp8( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_fp8 + + +def test_gemm_fp8(M, N, K, dtype): + torch_dtype = T.dtype(dtype).as_torch() + + kernel = matmul(M, N, K, 128, 128, 64, dtype) + + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + + c = kernel(a, b) + + ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype) + + print(c) + print(ref_c) + + diff = calc_diff(c, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + +def main(): + test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5d84d72f16cd7af7e1304bfbacd28c80795035 --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -0,0 +1,81 @@ +import torch +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): + # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. + # if block_K < 128, promote after 128/block_K iters. + # if block_K > 128, promote after every iter. 
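+    # C_local_accum keeps the running fp32 total across the whole K loop, while C_local
+    # only ever accumulates up to `update_interval` chunks before being folded in and
+    # cleared again; that bounded accumulation is the "2xAcc" promotion this example
+    # demonstrates.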
+ update_interval = 128 // block_K if block_K < 128 else 1 + + @T.prim_func + def gemm_fp8_2xAcc( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + if (k + 1) % update_interval == 0: + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] + T.clear(C_local) + # Tail processing + if K_iters % update_interval != 0: + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_fp8_2xAcc + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def test_gemm_fp8(M, N, K, dtype): + torch_dtype = T.dtype(dtype).as_torch() + + kernel = matmul(M, N, K, 128, 128, 64, dtype) + + a = torch.rand(M, K, dtype=torch.float16, device="cuda") + a = (100 * (2 * a - 1)).to(dtype=torch_dtype) + b = torch.rand(N, K, dtype=torch.float16, device="cuda") + b = (100 * (2 * b - 1)).to(dtype=torch_dtype) + + c = kernel(a, b) + + ref_c = a.float() @ b.float().T + + diff = calc_diff(c, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + +def main(): + test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..7ecde7c1b4e03cf7cc6f3f0284f4062d30e25cb5 --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -0,0 +1,228 @@ +import torch +from tilelang import tvm as tvm +import tilelang.testing +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, + ], "Currently only float16 and int8 are supported" + 
assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_fp8_intrinsic( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_fp8_intrinsic + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + kernel = tl_matmul(M, N, K, in_dtype, 
out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + print(src_code) + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + + C = torch.zeros(M, N, device="cuda", dtype=accum_dtype) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + C = profiler(A, B) + + latency = profiler.do_bench(warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def main(): + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e8b360805459bda83b73dd7334b1bd923a201 --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,124 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True 
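+# Module-level problem size and tile configuration for the FP8 dtype sweep
+# below. B is stored as (N, K), so trans_B=True makes the kernel compute
+# C = A @ B^T, matching the PyTorch reference used for the diff check.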
+num_stages = 2 +threads = 256 +for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: + for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/tilelang/original/examples/gemm_fp8/test_example_gemm_fp8.py b/tilelang/original/examples/gemm_fp8/test_example_gemm_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..19a9ee00a7cac624526585fc6caba191037ed46d --- /dev/null +++ b/tilelang/original/examples/gemm_fp8/test_example_gemm_fp8.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic +import example_tilelang_gemm_fp8 + + +def test_example_tilelang_gemm_fp8_2xAcc(): + example_tilelang_gemm_fp8_2xAcc.main() + + +def test_example_tilelang_gemm_fp8_intrinsic(): + example_tilelang_gemm_fp8_intrinsic.main() + + +def test_example_tilelang_gemm_fp8(): + example_tilelang_gemm_fp8.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gemm_sm100/README.md b/tilelang/original/examples/gemm_sm100/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28bb611bff167fdf7ba2291833edb91fd3d17beb --- /dev/null +++ b/tilelang/original/examples/gemm_sm100/README.md @@ -0,0 +1,106 @@ +# TileLang SM100 Support (Preview) + +This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality. + +## Current Limitations (Manual Implementation Required) + +### 1. Manual TCGEN5.MMA Management +Users must manually handle TCGEN5MMA operations using: +- `T.alloc_tmem()` - Allocate Tensor Memory +- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting +- Manual synchronization with mbarrier + +### 2. 
Manual mbarrier Synchronization +TCGEN5MMA is asynchronous and requires explicit synchronization: +```python +mbar = T.alloc_barrier(1) # expect-arrive-count = 1 +T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0) +T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required +``` + +## Examples + +### TCGEN5MMA Example (`gemm_tcgen5mma.py`) +Demonstrates TCGEN5MMA operations with: +- Tensor Memory allocation +- Manual mbarrier synchronization +- TCGEN5MMA gemm operations + +### Traditional MMA Example (`gemm_mma.py`) +Shows standard MMA operations that work across architectures for comparison. + +## Code Example + +The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication: + +```python +import torch +import tilelang +import tilelang.language as T + +@T.prim_func +def main( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), +): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + # 1. Allocate memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory + mbar = T.alloc_barrier(1) # mbarrier synchronization primitive + + C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage + C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory + + # 2. Main computation loop + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + # Data loading: global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + + # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory + T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, + mbar=mbar, wg_wait=-1, clear_accum=k==0) + + # Critical: wait for TCGEN5MMA completion + T.mbarrier_wait_parity(mbar, k%2) + + # 3. Output processing (only subset of threads) + T.copy(C_tmem, C_local) # Tensor Memory → registers + T.copy(C_local, C_shared) # registers → shared memory + + # 4. 
Write back to global memory + T.copy(C_shared, C[by * block_M, bx * block_N]) +``` + +### Compilation and Usage + +```python +# Parameter setup +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 + +# Compile kernel +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required +}) + +# Run test +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) + +# Verify correctness +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +# Performance benchmark +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") +``` + diff --git a/tilelang/original/examples/gemm_sm100/gemm_mma.py b/tilelang/original/examples/gemm_sm100/gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..226e33c01e474ec646ba7f7e7ac39c86a2497c6a --- /dev/null +++ b/tilelang/original/examples/gemm_sm100/gemm_mma.py @@ -0,0 +1,94 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # This is a sugar syntax for parallelized copy + # for i, k in T.Parallel(M, block_K): + # A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[bx * block_N, ko * block_K], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +M = 128 # M = T.dynamic("m") if you want to use dynamic shape +N = 128 +K = 32 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +func = matmul(M, N, K, block_M, block_N, block_K) + +# 2. Compile the kernel into a torch function +# out_idx specifies the index of the output buffer in the argument list +# if out_idx is specified, the tensor will be created during runtime +# target currently can be "cuda" or "hip" or "cpu". +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print(jit_kernel.get_kernel_source()) +# 3. 
Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(N, K, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +c = jit_kernel(a, b) + +print(c) +# Reference multiplication using PyTorch +ref_c = a @ b.T + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/tilelang/original/examples/gemm_sm100/gemm_tcgen5mma.py b/tilelang/original/examples/gemm_sm100/gemm_tcgen5mma.py new file mode 100644 index 0000000000000000000000000000000000000000..523a94fea6737bcd33f879ee8b49aaecaa3740af --- /dev/null +++ b/tilelang/original/examples/gemm_sm100/gemm_tcgen5mma.py @@ -0,0 +1,83 @@ +import torch +import tilelang +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 +trans_A, trans_B = False, True +in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float +num_stages = 2 +threads = 256 + +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print(jit_kernel.get_kernel_source()) + +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} 
ms") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/tilelang/original/examples/gemm_sp/example_custom_compress.py b/tilelang/original/examples/gemm_sp/example_custom_compress.py new file mode 100644 index 0000000000000000000000000000000000000000..7b93f2a779e2e56dd496712abf9fa16363feafa8 --- /dev/null +++ b/tilelang/original/examples/gemm_sp/example_custom_compress.py @@ -0,0 +1,336 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import randn_semi_sparse +from tilelang.utils.tensor import torch_assert_close + +from triton.testing import do_bench + +import torch + +torch.manual_seed(42) + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): + e_factor, e_dtype = (16, T.int16) + + @T.prim_func + def gemm_sp_fp16_custom_compress( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16_custom_compress + + +def torch_compress(dense): + """ + A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. 
+ """ + if dense.dim() != 2: + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + + m, k = dense.shape + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + return (sparse, meta) + + +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 # 4 groups per uint16 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) + + +@tilelang.jit( + out_idx=[1, 2], + pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, + }, +) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(A_sp_shared) + T.clear(E_shared) + # TODO: alloc_var seems buggy here + non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + non_zero_cnt[0] = 0 + for i in range(elem): + non_zero_elt_log_idx[i] = 0 + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = 
A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel + + +def main(): + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + kernel = matmul_sp_fp16_custom_compress( + args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout + ) + + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) + + if args.use_torch_compressor: + assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" + a_sparse, e = torch_compress(a) + else: + a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a) + + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) + print(f"Precision check passed. 
Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * args.m * args.n * args.k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_sp/example_gemm_sp.py b/tilelang/original/examples/gemm_sp/example_gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..10f524adbc791867f94a8af203ce787217a13183 --- /dev/null +++ b/tilelang/original/examples/gemm_sp/example_gemm_sp.py @@ -0,0 +1,133 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import compress, randn_semi_sparse +from tilelang.contrib import nvcc +from triton.testing import do_bench + +import torch + +arch = nvcc.get_target_compute_version() + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): + e_factor, e_dtype = ARCH_INFO[arch] + + @T.prim_func + def gemm_sp_fp16( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), + } + ) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + 
T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16 + + +def main(): + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") + args = parser.parse_args() + kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) + + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) + + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch) + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * args.m * args.n * args.k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_sp/test_example_gemm_sp.py b/tilelang/original/examples/gemm_sp/test_example_gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..fe26df14497e392bd0cdc03b9a3352d4fbd2f24b --- /dev/null +++ b/tilelang/original/examples/gemm_sp/test_example_gemm_sp.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_custom_compress +import example_gemm_sp + + +def test_example_custom_compress(): + example_custom_compress.main() + + +def test_example_gemm_sp(): + example_gemm_sp.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..62073c5bddfa93959ad21b98c7b7cda3ffcb1e76 --- /dev/null +++ b/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -0,0 +1,60 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): + splitK = K // split_k + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + 
ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j]) + + return main + + +def main(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + kernel(a, b, c) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py new file mode 100644 index 0000000000000000000000000000000000000000..83e83b5d2a7fff6eddbed10ddf63a19f8d18d425 --- /dev/null +++ b/tilelang/original/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -0,0 +1,59 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): + splitK = K // split_k + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + T.atomic_add(C[by * block_M, bx * block_N], C_shared) + + return main + + +def main(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + kernel(a, b, c) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_splitk/test_example_gemm_splitk.py b/tilelang/original/examples/gemm_splitk/test_example_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..055b09162767d4a208bdec0d7b5ca8ccefec772c --- /dev/null +++ b/tilelang/original/examples/gemm_splitk/test_example_gemm_splitk.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def test_example_tilelang_gemm_splitk(): + example_tilelang_gemm_splitk.main() + + +def test_example_tilelang_gemm_splitk_vectorize_atomicadd(): + example_tilelang_gemm_splitk_vectorize_atomicadd.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gemm_streamk/example_tilelang_gemm_streamk.py 
b/tilelang/original/examples/gemm_streamk/example_tilelang_gemm_streamk.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec1541ea4aac095dc34b69ea55bdd73d66f4db7 --- /dev/null +++ b/tilelang/original/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -0,0 +1,203 @@ +import torch +import torch.backends +import tilelang +from tilelang import language as T +import math + + +def cdiv(a, b): + return math.ceil(a / b) + + +# disable tf32 +torch.backends.cuda.matmul.allow_tf32 = False + +m = 256 +n = 1024 +k = 512 + +total_sm = 108 + +torch.random.manual_seed(0) +# uniform distribution from -1 to 1 +A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1 +B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1 + +streamk_programs = total_sm +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 32 +two_tiles = False +M, K = A.shape +N, K = B.shape +# accumulator types +# compute grid (work to do per SM on the first wave) +num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M) +num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N) +iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K) +total_tiles = num_block_m * num_block_n + +# Two-tile SK + DP +streamk_tiles = total_tiles % streamk_programs +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) + streamk_tiles += streamk_programs + +blocking_tiles = total_tiles - streamk_tiles +streamk_iters = streamk_tiles * iters_per_tile + +streamk_full_tiles = streamk_iters // streamk_programs +streamk_partial_tiles = streamk_iters % streamk_programs + +print(f"{total_tiles=} ") +print(f"{iters_per_tile=} ") + +sm_patition_factor = max(blocking_tiles // total_sm, 1) + + +@tilelang.jit +def tl_matmul_streamk( + M, + N, + K, + streamk_tiles, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + assert not trans_A + A_shape = (M, K) if not trans_A else (K, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.macro + def compute_first_wave( + pid: T.int32, + A_buf: T.Tensor, + A_buf_shared: T.SharedBuffer, + B_buf: T.Tensor, + B_buf_shared: T.SharedBuffer, + C: T.Tensor, + C_local: T.LocalBuffer, + ): + start_iter = T.alloc_fragment((1,), T.int32, "local") + end_iter = T.alloc_fragment((1,), T.int32, "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) + + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_buf_shared, + ) + T.copy( + B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_buf_shared, + ) + T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in 
T.Parallel(block_M, block_N): + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) + + start_iter[0] = end_iter[0] + + @T.macro + def compute_full_tiles( + pid: T.int32, + A_buf: T.Tensor, + A_shared: T.SharedBuffer, + B_buf: T.Tensor, + B_shared: T.SharedBuffer, + C: T.Tensor, + C_local: T.LocalBuffer, + ): + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) + T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + + @T.prim_func + def main( + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), + ): + with T.Kernel(streamk_programs, threads=threads) as pid: + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + + if sm_patition_factor > 0: + compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + + return main + + +def main(): + kernel = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + T.float16, + T.float16, + T.float32, + 2, + 64, + ) + + print(kernel.get_kernel_source()) + + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + + kernel(A, B, b_c) + + C = torch.matmul(A, B.T) + + print(b_c) + print(C) + torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py b/tilelang/original/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..a26ba74aede947a589923f7d1a57de3a14435de2 --- /dev/null +++ b/tilelang/original/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py @@ -0,0 +1,14 @@ +import tilelang.testing + +from example_tilelang_gemm_streamk import main + + +# not fully supported on sm90 +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_tilelang_gemm_streamk(): + main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/gemv/example_gemv.py b/tilelang/original/examples/gemv/example_gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd0e4dd9f88cc4d74500e9b773cfcdb38999704 --- /dev/null +++ b/tilelang/original/examples/gemv/example_gemv.py @@ -0,0 +1,368 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tvm import DataType +from tilelang.autotuner import autotune +from tilelang import jit + + +def ref_program(A, B): + return A @ B.T + + +@tl.jit(out_idx=[-1]) +def naive_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: + tn = 
T.get_thread_binding(0) # tn = threadIdx.x + A_shared = T.alloc_shared((BLOCK_K,), dtype) + B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype) + C_reg = T.alloc_local((1,), accum_dtype) + T.clear(C_reg) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for tk in T.serial(BLOCK_K): + A_shared[tk] = A[bk * BLOCK_K + tk] + B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + for tk in T.serial(BLOCK_K): + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) + C[bn * BLOCK_N + tn] = C_reg[0] + + return main + + +@tl.jit(out_idx=[-1]) +def naive_splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((1,), dtype) + B_local = T.alloc_local((1,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + A_local[0] = A[bk * BLOCK_K + tk] + B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + TILE_K = T.ceildiv(BLOCK_K, reduce_threads) + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.serial(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv_vectorized( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = 
A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv_vectorized_tvm( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + ) + ) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + +def get_block_template_configs(): + iter_params = dict( + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tl.autotune( + configs=get_block_template_configs(), + warmup=3, + rep=20, +) +@tl.jit( + pass_configs={ + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + out_idx=[2], +) +def gemv_alloc_reducer( + M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float +): + @T.prim_func + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore + with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: + o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") + T.clear(o_reducer) + for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages): + a_smem = T.alloc_shared((block_M, block_N), dtype) + T.copy(a[i0_m * block_M, i0_n * block_N], a_smem) + a_frag = T.alloc_fragment((block_M, block_N), dtype) + T.copy(a_smem, a_frag) + x_frag = T.alloc_fragment(block_N, dtype) + T.copy(x[i0_n * block_N], x_frag) + for i1_m, i1_n in T.Parallel(block_M, block_N): + o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n] + T.finalize_reducer(o_reducer) + T.copy(o_reducer, o[i0_m * block_M]) + + return main + + +def get_thread_template_configs(): + iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_thread_template_configs(), + warmup=3, + rep=20, +) +@jit( + out_idx=[-1], + target="auto", +) +def 
get_autotuned_kernel( + N, + K, + BLOCK_N=None, + reduce_threads=None, +): + dtype = T.float16 + accum_dtype = T.float32 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + ) + ) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + +def check_correctness_and_bench(kernel, N, K, do_bench=True): + profiler = kernel.get_profiler() + profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) + if do_bench: + latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) + print(f"Torch Latency: {latency} ms") + latency = profiler.do_bench(kernel, warmup=50) + print(f"TileLang Latency: {latency} ms\n") + + +def main(do_bench: bool = True): + parser = argparse.ArgumentParser(description="GEMV Example") + parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") + args, _ = parser.parse_known_args() + N, K = args.n, args.k + check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench) + check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) + + print("Test passed!") + + if do_bench: + best_result = get_autotuned_kernel(N, K) + best_config = best_result.config + kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) + profiler = kernel.get_profiler() + latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500) + print(f"Torch Latency: {latency} ms") + tilelang_thread_latency = profiler.do_bench(kernel, warmup=500) + print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n") + kernel = gemv_alloc_reducer(N, K) + profiler = kernel.get_profiler() + tilelang_tile_latency = profiler.do_bench(kernel, warmup=500) + print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/gemv/test_example_gemv.py b/tilelang/original/examples/gemv/test_example_gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..323337a7a6a0f21f79ed4455d8243e3561f3847a --- 
/dev/null +++ b/tilelang/original/examples/gemv/test_example_gemv.py @@ -0,0 +1,9 @@ +import example_gemv + + +def test_example_gemv(): + example_gemv.main(do_bench=False) + + +if __name__ == "__main__": + test_example_gemv() diff --git a/tilelang/original/examples/grouped_gemm/example_grouped_gemm_bwd.py b/tilelang/original/examples/grouped_gemm/example_grouped_gemm_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..bb57c60731a4b495028612546def8b24324940a8 --- /dev/null +++ b/tilelang/original/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -0,0 +1,239 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). + """ + accum_dtype = T.float32 + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared([block_M, block_K], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) + + m_start_padded = bx * block_M + + for i in range(batch_count): + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + + cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + with T.If(i < actual_rows), T.Then(): + C[m_start + i, by * block_N + j] = C_local[i, j] + + return kernel + + +class _GroupedGEMM(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b, batch_sizes): + block_M = 64 + block_N = 64 + block_K = 64 + padding_M = block_M + num_stages = 2 + threads = 128 + batch_sum = a.shape[0] + batch_count = b.shape[0] + K = a.shape[1] + N = b.shape[2] + + assert a.shape[1] == b.shape[1] + assert batch_sizes.shape[0] == batch_count + assert batch_sizes.sum() == batch_sum + + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) + 
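For concreteness, here is a small standalone sketch (not part of the example) of how the two offset lists built above differ, using made-up batch sizes: `batch_offsets` locates each group's rows in the packed `A`/`C` tensors, while `batch_padded_offsets` rounds every group up (with the same `+ 1` used above) to whole `padding_M` blocks, so each block of the launch grid falls inside exactly one group.

```python
# Illustrative only: recompute the two offset arrays for a hypothetical
# batch_sizes of [64, 100, 32] with padding_M = 64.
import math

batch_sizes = [64, 100, 32]
padding_M = 64

# Row offset of each group in the packed (unpadded) A/C tensors.
batch_offsets = [0]
for s in batch_sizes[:-1]:
    batch_offsets.append(batch_offsets[-1] + s)

# Row offset of each group in the padded block grid: every group is rounded
# up to a whole number of padding_M-row blocks.
batch_padded_offsets = [0]
for s in batch_sizes[:-1]:
    batch_padded_offsets.append(batch_padded_offsets[-1] + math.ceil((s + 1) / padding_M) * padding_M)

print(batch_offsets)         # [0, 64, 164]
print(batch_padded_offsets)  # [0, 128, 256]
```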
batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) + + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) + + o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) + ctx.save_for_backward(a, b, batch_sizes, batch_offsets) + ctx.batch_sum = batch_sum + ctx.batch_count = batch_count + ctx.K = K + return o + + @staticmethod + def backward(ctx, grad_output): + block_M = 64 + block_N = 64 + block_K = 64 + num_stages = 2 + threads = 128 + + M = ctx.K + N = grad_output.shape[1] + + A, B, batch_sizes, batch_offsets = ctx.saved_tensors + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) + + dB = kernel(A, grad_output, batch_sizes, batch_offsets) + return None, dB, None + + +def ref_program(a, b, batch_sizes): + assert a.shape[0] == sum(batch_sizes) + assert b.shape[0] == len(batch_sizes) + + output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) + + start = 0 + a_list = [] + b_list = [] + for i, size in enumerate(batch_sizes): + end = start + size + part_a = a[start:end] + part_b = b[i] + output[start:end] = torch.mm(part_a, part_b) + + a_list.append(part_a) + b_list.append(part_b) + start = end + + return output + + +def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) + A = torch.randn(batch_sum, K, device=device, dtype=dtype) + B = torch.randn(batch_count, K, M, device=device, dtype=dtype) + C = torch.empty(batch_sum, M, device=device, dtype=dtype) + batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32) + batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32) + # print(batch_sizes_tensor) + # print(batch_offsets_tensor) + # print(batch_padded_offsets_tensor) + return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets + + +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). 
+ """ + accum_dtype = T.float32 + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared([block_K, block_M], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): + for i, j in T.Parallel(block_K, block_M): + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) + T.gemm(A_shared, B_shared, C_local, transpose_A=True) + + T.copy(C_local, C[bz, bx * block_M, by * block_N]) + + return kernel + + +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): + padding_M = block_M + device = torch.device("cuda") + dtype = torch.float16 + + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) + + A.requires_grad_(False) + B.requires_grad_(True) + O_ref = ref_program(A, B, batch_sizes) + dO = torch.randn_like(O_ref) + + O_ref.backward(dO, retain_graph=True) + dB_ref, B.grad = B.grad.clone(), None + + GroupedGEMM = _GroupedGEMM.apply + O = GroupedGEMM(A, B, batch_sizes) + O.backward(dO, retain_graph=True) + dB, B.grad = B.grad.clone(), None + + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): + print("✅ Tilelang and Torch match") + else: + print("❌ Tilelang and Torch mismatch") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") + args = parser.parse_args() + + batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] + K, M, trans_b = args.K, args.M, args.trans_b + + block_M = 64 + block_N = 128 + block_K = 64 + num_stages = 2 + threads = 256 + + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/tilelang/original/examples/grouped_gemm/example_grouped_gemm_fwd.py b/tilelang/original/examples/grouped_gemm/example_grouped_gemm_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..48d91605145405a94c6e697207bb63ee49bc6a66 --- /dev/null +++ b/tilelang/original/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -0,0 +1,163 @@ +import torch +import argparse +import tilelang +import tilelang.language as T +import math + + +def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): + """ + Perform grouped matrix multiplication using PyTorch. 
+ + Args: + a (torch.Tensor): Input tensor of shape (N, K). + b (torch.Tensor): Input tensor of shape (G, K, M). + batch_sizes (torch.Tensor): 1D tensor containing the sizes of each group. + + Returns: + torch.Tensor: Resulting tensor after grouped matrix multiplication. + """ + assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" + + # Initialize output tensor + output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) + + # Perform grouped GEMM + start = 0 + for i, size in enumerate(batch_sizes): + end = start + size + part_a = a[start:end] + part_b = b[i].transpose(0, 1) if trans_b else b[i] + part_out = torch.mm(part_a, part_b) + output[start:end] = part_out + start = end + + return output + + +@tilelang.jit(out_idx=[2]) +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). + """ + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + accum_dtype = T.float32 + total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared([block_M, block_K], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) + + m_start_padded = bx * block_M + + for i in range(batch_count): + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + + cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + with T.If(i < actual_rows), T.Then(): + C[m_start + i, by * block_N + j] = C_local[i, j] + + return kernel + + +def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + A = 
torch.randn(batch_sum, K, device=device, dtype=dtype) + B = torch.randn(batch_count, K, M, device=device, dtype=dtype) + C = torch.empty(batch_sum, M, device=device, dtype=dtype) + batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32) + batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32) + # print(batch_sizes_tensor) + # print(batch_offsets_tensor) + # print(batch_padded_offsets_tensor) + return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets + + +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): + padding_M = block_M + batch_sum = sum(batch_sizes_list) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + # print(kernel.get_kernel_source()) + + device = torch.device("cuda") + dtype = torch.float16 + + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) + ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) + # print(out) + # print(ref_output) + if torch.allclose(out, ref_output, rtol=0.01, atol=0.01): + print("✅ Tilelang and Torch match") + else: + print("❌ Tilelang and Torch mismatch") + + if profile: + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + print(f"Latency: {latency} ms") + print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") + + +def test_grouped_gemm(): + run_tilelang_grouped_gemm([64], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([64, 128, 256], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([63], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([100, 200, 300, 400], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([63, 77, 111, 280], 8192, 8192, 64, 64, 64, False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") + args = parser.parse_args() + + batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] + K, M, trans_b = args.K, args.M, args.trans_b + + block_M = 64 + block_N = 128 + block_K = 64 + num_stages = 2 + threads = 256 + + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/tilelang/original/examples/hadamard_transform/example_hadamard.py b/tilelang/original/examples/hadamard_transform/example_hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..65f463b71bb06c53ab579a1c0389e71b0b1c387e --- /dev/null +++ b/tilelang/original/examples/hadamard_transform/example_hadamard.py @@ -0,0 +1,153 @@ +import tilelang +import tilelang.language as T +from tilelang.intrinsics import make_mma_swizzle_layout + +import math +import argparse +import torch +from torch.nn import functional as F 
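As an aside (not part of the example file), the kernel defined below is essentially a fast Walsh–Hadamard transform whose butterfly stages are split across registers, warp shuffles, and shared memory; the sketch that follows shows the same recurrence as a minimal O(n log n) PyTorch function, with the helper name `fwht` being an assumption of this note. The example's own reference instead multiplies by the dense `scipy.linalg.hadamard` matrix.

```python
# Illustrative sketch only: an unnormalized fast Walsh-Hadamard transform.
# Each while-iteration is one butterfly stage; the TileLang kernel performs
# the same log2(n) stages in registers, then warp shuffles, then shared memory.
import torch


def fwht(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    y = x.clone()
    h = 1
    while h < n:
        # Pair up elements h apart and apply the (a + b, a - b) butterfly.
        y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    return y.reshape(x.shape)


# Sanity check: fwht(torch.eye(8)) reproduces scipy.linalg.hadamard(8) up to dtype.
```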
+import scipy + + +def is_pow_of_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + +@tilelang.jit(out_idx=[1]) +def hadamard(b, n, dtype): + assert is_pow_of_2(n), "n must be a power of 2" + assert 2 <= n <= 32768, "n must be in [2, 32768]" + elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype] + + logN = int(math.log2(n)) + threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] + thread_elem = n // threads # Each thread is responsible for a chunk of elements + thread_round = int(math.log2(thread_elem)) + + warps = 1 if threads <= 32 else threads // 32 + warp_round = int(math.log2(threads / warps)) + warp_size = threads // warps + + block_round = int(math.log2(warps)) + + exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1 # Suppose we use 32KB shared memory at most + thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem + + # debug log + # print(f'{threads=}, {thread_round=}') + # print(f'{warps=}, {warp_round=}, {warp_size=}') + # print(f'{block_round=}') + # print(f'{exchange_round=}') + + @T.macro + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int): + tx = T.get_thread_binding(0) + for i in T.serial(round): + tx_stride = 1 << i + another_tx = tx ^ tx_stride + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + + for j in T.Pipelined(thread_elem, num_stages=1): + buf[j] = T.tvm_warp_shuffle( + 0xFFFFFFFF, # mask of all threads + local[j], + another_tx % warp_size, + warp_size, + warp_size, + ) + local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) + + @T.prim_func + def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)): + with T.Kernel(b, threads=threads) as bx: + local = T.alloc_local((thread_elem,), dtype) + shared = T.alloc_shared((threads, thread_elem_in_smem), dtype) + T.annotate_layout({shared: make_mma_swizzle_layout(shared)}) + tx = T.get_thread_binding(0) + + # 1. Load from HBM to register + for i in T.vectorized(thread_elem): + local[i] = A[bx, tx * thread_elem + i] + + # 2. Hadamard inside thread, n<=8 + for i in T.serial(thread_round): + chunksize = 1 << (i + 1) + chunknum = thread_elem // chunksize + for j in T.serial(chunknum): + chunkbase = j * chunksize + for k in T.serial(chunksize // 2): + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + + # 3. Hadamard inside warp, n<=512 + # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory + another_val = T.alloc_local((thread_elem,), dtype) + + warp_shfl(local, another_val, warp_round) + + # 4. 
Hadamard inside block, n<=32768 + # Only exchange once for n<=8192, since shared mem can hold all elems + if block_round > 0: + warp_id = tx // warp_size + lane_id = tx % warp_size + src_tx = warp_id * warp_size + lane_id + tgt_warp_id = tx % warps + tgt_lane_id = tx // warps + tgt_tx = tgt_warp_id * warp_size + tgt_lane_id + + # 4.1 Write to smem, swap, read from smem + for cur_round in T.serial(exchange_round): + exchange_base = thread_elem_in_smem * cur_round + for j in T.vectorized(thread_elem_in_smem): + shared[src_tx, j] = local[exchange_base + j] + + for j in T.vectorized(thread_elem_in_smem): + local[exchange_base + j] = shared[tgt_tx, j] + + # 4.2 Warp shuffle + warp_shfl(local, another_val, block_round) + + # 4.3 Write to smem, swap, read from smem + for cur_round in T.serial(exchange_round): + exchange_base = thread_elem_in_smem * cur_round + for j in T.vectorized(thread_elem_in_smem): + shared[tgt_tx, j] = local[exchange_base + j] + + for j in T.vectorized(thread_elem_in_smem): + local[exchange_base + j] = shared[src_tx, j] + + # 5. Write back to HBM + for i in T.vectorized(thread_elem): + B[bx, tx * thread_elem + i] = local[i] + + return main + + +def ref_program(x: torch.Tensor): + assert x.ndim == 2 + dim = x.shape[-1] + assert is_pow_of_2(dim) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") + args = parser.parse_args() + + B, D = args.batch, args.dim + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, T.float32) + y = kernel(x) + y_ref = ref_program(x) + torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) + print("All tests passed.") + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency = profiler.do_bench(warmup=100) + print("Tile-lang: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/lazy_jit/lazyjit.en.ipynb b/tilelang/original/examples/lazy_jit/lazyjit.en.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..99cb977f0066b97da63368e5f550ef160bd52f1d --- /dev/null +++ b/tilelang/original/examples/lazy_jit/lazyjit.en.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT combines the jit generation and invocation logic.\n", + "\n", + "The function signature syntax is similar to triton but with significant enhancements, most notably allowing Tensor annotations:\n", + "\n", + "For example, the code below annotates a 2D Tensor with T.Tensor[[int, int], T.float16]\n", + "1. Each dimension is a compile-time constant; changing it triggers recompilation\n", + "2. 
Its dtype must be T.float16\n", + "\n", + "DType can also be Any or None in addition to a concrete type\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A: T.Tensor[[int, int], T.float16],\n", + " B: T.Tensor[[int, int], T.float16],\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), out_dtype)\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "Call the Tensor directly as an argument to trigger the full jit compile-and-run workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "Change the call-site arguments; if the compiler parameters differ, it recompiles:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "You can also manually call compile helpers to build a kernel\n", + "\n", + "1. `ker.compile` compiles the kernel\n", + "2. `ker.get_tir` retrieves the tir\n", + "3. `ker.par_compile` compiles in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n" + ] + } + ], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### Separate the implementation with macros" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "Next we'll implement a simple gemm in several ways. 
For convenience, first write a macro that captures the main gemm logic:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### Mark dynamic shapes with T.dyn\n", + "\n", + "When some dimensions are dynamic, mark them with T.dyn. T.dyn can take a string argument to name the variable" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "c60fd346", + "metadata": {}, + "source": [ + "Inspect the lazy_jit function signature: parameters with a `$` suffix are compile-time constants that may vary, and those with `$dyn` are runtime variables" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c6992eb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", + " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gemm_dyn_K.func.annot" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### Use T.StridedTensor to annotate tensors with strides\n", + "\n", + "Annotation format: T.StridedTensor[Shape, Stride, DType]. 
Each Shape or Stride entry can be\n", + "* int: compile-time constant\n", + "* T.dyn: runtime value\n", + "\n", + "DType can be None or Any" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", + " M, N = A.shape\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A[::2, ::2])\n", + "B_ref = A[::2, ::2].contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### Annotate tensors with T.ptr\n", + "lazy_jit lets you declare a handle with T.ptr, but you must define its shape inside the function via T.match_buffer" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16)\n", + " B = T.match_buffer(B, (K, N), T.float16)\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### Use T.int32 to annotate runtime variables\n", + "\n", + "lazy_jit lets you define runtime variables with T.int32 or other types, enabling a fully dynamic gemm similar to triton" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: T.int32,\n", + " N: T.int32,\n", + " K: T.int32,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", + " B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A 
@ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## Compilation and parallel compilation" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "lazyjit and the original jit both support parallel compilation\n", + "\n", + "To avoid wasting memory with torch.tensor placeholders, use T.Tensor to create placeholders" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6d7f05cdfff412e9a527332438f7aa2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## More convenient macros" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang macros are now upgraded:\n", + "\n", + "1. Allow `T.Ref` as an annotation, similar to C++ pass-by-reference\n", + "2. Allow returning multiple values\n", + "3. 
Allow nesting and recursion" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### Passing references with T.Ref\n", + "\n", + "The reference via T.Ref can target a var or a buffer element" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # Supports constant indices\n", + " macro_with_ref(x[1])\n", + "\n", + " # Also supports variable indices\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### Pass as arguments\n", + "\n", + "You can pass macros as parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(\n", + " A: T.Tensor[[T.dyn], Any],\n", + " fn,\n", + "):\n", + " (N,) = A.shape\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro recursion\n", + "\n", + "Macro can be recursive, even if it's rarely needed, as long as the termination condition is known at compile time" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": 
"542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro returning multiple values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/tilelang/original/examples/lazy_jit/lazyjit.zh.ipynb b/tilelang/original/examples/lazy_jit/lazyjit.zh.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..601c5c5d2fe3610adfed8edd890d6a701f5c49f8 --- /dev/null +++ b/tilelang/original/examples/lazy_jit/lazyjit.zh.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT 将 jit 生成和调用的逻辑合并到一起\n", + "\n", + "函数签名的写法与 triton 相似,但做了大量增强,最主要的增强是允许对 Tensor 的标注:\n", + "\n", + "例如,下面的代码用 T.Tensor[[int, int], T.float16] 来标注了一个二维 Tensor\n", + "1. 它的每个维度都是编译期常量,如果改变,会触发重新编译\n", + "2. 
它的类型必须是 T.float16\n", + "\n", + "DType 除了写确定的外,还可以写 Any 或者 None" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A: T.Tensor[[int, int], T.float16],\n", + " B: T.Tensor[[int, int], T.float16],\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), out_dtype)\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "直接将 Tensor 作为参数调用,即可触发完整的 jit 编译运行流程:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "更改调用的参数,如果编译器参数发生了变化,会触发重新编译:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "你也可以手动调用 compile 函数编译 kernel\n", + "\n", + "1. `ker.compile` 编译 kernel\n", + "2. `ker.get_tir` 获取 tir\n", + "3. `ker.par_compile` 并行编译" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. 
For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n" + ] + } + ], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### 用 macro 来分离实现" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "接下来,我们会用各种方式来实现一个简单的 gemm,为了方便,我们先写一个 macro 把 gemm 的主要逻辑写出来:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### 用 T.dyn 标记动态 Shape\n", + "\n", + "当某些维度是动态的的时候,可以用 T.dyn 来标记。T.dyn 可以接受一个字符串参数,表示变量的名字" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "c60fd346", + "metadata": {}, + "source": [ + "查看 lazy_jit 的函数签名,其中带有后缀`$` 的是不确定的编译期常量,带有 `$dyn` 的是运行时的变量" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c6992eb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", + " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gemm_dyn_K.func.annot" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### 用 T.StridedTensor 标记带 stride 的 Tensor\n", + "\n", + "标记方法:T.StridedTensor[Shape, Stride, DType],每个 Shape 或 Stride 可以写\n", + "* int: 表示编译期常量\n", + "* T.dyn:表示运行时常量\n", + "\n", + "DType 可以写 None 或 Any" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", 
+ "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", + " M, N = A.shape\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A[::2, ::2])\n", + "B_ref = A[::2, ::2].contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### 用 T.ptr 标注 Tensor\n", + "lazy_jit 允许你用 T.ptr 来声明一个 handle,但必须在函数内用 T.match_buffer 给它定义 shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16)\n", + " B = T.match_buffer(B, (K, N), T.float16)\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### 用 T.int32 标注运行时变量\n", + "\n", + "lazy_jit 允许你用 T.int32 或其他类型来定义运行时变量,这样,你可以写一个完全动态的 gemm,这和 triton 非常相似" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: T.int32,\n", + " N: T.int32,\n", + " K: T.int32,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", + " B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## 编译与并行编译" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "lazyjit 和原来的 jit 都支持并行编译\n", + "\n", + "为了防止 torch.tensor 白白浪费内存,可以使用 T.Tensor 来创建 placeholder" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7222e57b", + "metadata": 
{}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6d7f05cdfff412e9a527332438f7aa2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## 更便利的 Macro" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang 的 macro 现在已经升级:\n", + "\n", + "1. 允许用 `T.Ref` 作为 annotation,这类似与 C++ 的引用传递\n", + "2. 允许返回多个值\n", + "3. 允许嵌套,递归" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### T.Ref 传递引用\n", + "\n", + "T.Ref 传递的引用可以 var 也可以是 Buffer 的索引" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # 支持常量 index\n", + " macro_with_ref(x[1])\n", + "\n", + " # 也支持变量 index\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### 当作参数传递\n", + "\n", + "你可以把 macro 当做参数传递" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(\n", + " A: T.Tensor[[T.dyn], Any],\n", + " fn,\n", + "):\n", + " (N,) = A.shape\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + 
"execution_count": 19, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro 递归\n", + "\n", + "虽然不知道有没有这种需求,但 macro 是可以递归的,但要求终止条件编译期间确定" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro 返回多个值" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/tilelang/original/examples/linear_attention/README.md b/tilelang/original/examples/linear_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..92b10692b32a8c9f1aa5ed979510acd5321f84e4 --- /dev/null +++ b/tilelang/original/examples/linear_attention/README.md @@ -0,0 +1 @@ +# Linear Attention diff --git a/tilelang/original/examples/linear_attention/example_linear_attn_bwd.py b/tilelang/original/examples/linear_attention/example_linear_attn_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..397ec7bdf6fe6d27a47e3d29c82d1b331ebb277d --- /dev/null +++ 
b/tilelang/original/examples/linear_attention/example_linear_attn_bwd.py @@ -0,0 +1,203 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +from fla.modules.l2norm import l2norm_fwd +from einops import rearrange +from typing import Optional, Tuple + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tl_fused_chunk_bwd_kernel( + B, + S, + H, + DK, + DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tilelang.cdiv(DK, BK) + NV = tilelang.cdiv(DV, BV) + NT = tilelang.cdiv(S, chunk_size) + + @T.prim_func + def fused_chunk_linear_attn_bwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + ): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + dq = T.alloc_fragment([chunk_size, BK], accum_dtype) + dq_shared = T.alloc_shared([chunk_size, BK], accum_dtype) + dk = T.alloc_fragment([chunk_size, BK], accum_dtype) + dk_shared = T.alloc_shared([chunk_size, BK], accum_dtype) + dv = T.alloc_fragment([chunk_size, BV], accum_dtype) + dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype) + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + do = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BV, BK], accum_dtype) + h_shared = T.alloc_shared([BV, BK], dtype) + dh = T.alloc_fragment([BK, BV], accum_dtype) + dh_shared = T.alloc_shared([BK, BV], dtype) + + T.annotate_layout( + { + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + T.use_swizzle(10) + + T.clear(h) + T.clear(dh) + + # Calculate dQ + for i in T.Pipelined(0, NT): + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) + + T.gemm(do, v, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row >= col, ds[row, col], 0) + + T.gemm(ds_shared, k, dq, clear_accum=True) + T.copy(h, h_shared) + T.gemm(do, h_shared, dq) + T.gemm(v, k, h, transpose_A=True) + for row, col in T.Parallel(chunk_size, BK): + dq[row, col] *= scale + T.copy(dq, dq_shared) + T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared) + + # Calculate dK, dV (reversely) + for i in T.Pipelined(1, 
NT + 1): + start = NT - i + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) + + # Calculate dk + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) + T.gemm(ds_shared, q, dk, clear_accum=True) + T.copy(dh, dh_shared) + T.gemm(v, dh_shared, dk, transpose_B=True) + + # Calculate dv + T.gemm(k, q, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) + T.gemm(ds_shared, do, dv, clear_accum=True) + T.gemm(k, dh_shared, dv) + + # Update dh + T.gemm(q, do, dh, transpose_A=True) + + T.copy(dk, dk_shared) + T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared) + T.copy(dv, dv_shared) + T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared) + + return fused_chunk_linear_attn_bwd + + +def tl_fused_chunk_bwd(Q, K, V, dO): + B, S, H, D = Q.shape + kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros_like(K, dtype=torch.float32) + dV = torch.zeros_like(V, dtype=torch.float32) + kernel(Q, K, V, dO, dQ, dK, dV) + return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) + + +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: + q, k, v = q.float(), k.float(), v.float() + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + h = kv[:, :, -1, :, :] + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v + o = inter + intra + return rearrange(o, "b h n c d -> b (n c) h d"), h + + +def main(B=1, S=1024, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + # qk norm is necessary for linear attn + q = l2norm_fwd(q)[0].requires_grad_(True) + k = l2norm_fwd(k)[0].requires_grad_(True) + + dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do) + q.grad = k.grad = v.grad = None + o_ref, _ = ref_program(q, k, v) + o_ref.backward(do, retain_graph=True) + + assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}" + assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}" + assert 
torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}" + print("Passed all tests!✅") + + # Benchmark + q.grad = k.grad = v.grad = None + o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + + main(args.B, args.S, args.H, args.D) diff --git a/tilelang/original/examples/linear_attention/example_linear_attn_fwd.py b/tilelang/original/examples/linear_attention/example_linear_attn_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..849841e5179e3a3baccf1bb7794e425cd5fee990 --- /dev/null +++ b/tilelang/original/examples/linear_attention/example_linear_attn_fwd.py @@ -0,0 +1,149 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +from fla.modules.l2norm import l2norm_fwd +from einops import rearrange +from typing import Optional, Tuple + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def tl_fused_chunk_fwd_kernel( + B, + S, + H, + DK, + DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tilelang.cdiv(DK, BK) + NV = tilelang.cdiv(DV, BV) + NT = tilelang.cdiv(S, chunk_size) + + @T.prim_func + def fused_chunk_linear_attn_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype), + ): # type: ignore + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + o_shared = T.alloc_shared([chunk_size, BV], accum_dtype) + + T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)}) + T.use_swizzle(10) + + T.clear(h) + + for i in T.Pipelined(0, NT): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : 
(i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0) + + T.gemm(s_shared, v, o, clear_accum=True) + T.copy(h, h_shared) + T.gemm(k, v, h, transpose_A=True) + T.gemm(q, h_shared, o) + T.copy(o, o_shared) + T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared) + + # Output final state + T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV]) + + return fused_chunk_linear_attn_fwd + + +def tl_fused_chunk_fwd(q, k, v): + B, S, H, D = q.shape + kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) + print(kernel.get_kernel_source()) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) + h = kernel(q, k, v, o) + return o, h + + +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: + q, k, v = q.float(), k.float(), v.float() + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + h = kv[:, :, -1, :, :] + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v + o = inter + intra + return rearrange(o, "b h n c d -> b (n c) h d"), h + + +def main(B=1, S=512, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + # qk norm is necessary for linear attn + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + o, h = tl_fused_chunk_fwd(q, k, v) + o_ref, h_ref = ref_program(q, k, v) + + assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}" + assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}" + print("Passed all tests!✅") + + t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + + main(args.B, args.S, args.H, args.D) diff --git a/tilelang/original/examples/linear_attention/example_mamba_chunk_scan.py b/tilelang/original/examples/linear_attention/example_mamba_chunk_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..1958dfb5aa95a64fd38e40f6632b787b39150c19 --- /dev/null +++ b/tilelang/original/examples/linear_attention/example_mamba_chunk_scan.py @@ -0,0 +1,285 @@ +import argparse +import torch +import tilelang +from tilelang.autotuner import * 
+import tilelang.language as T +from einops import rearrange, repeat +import itertools + + +def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) + return out + + +def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + _, _, ngroups, _, _ = cb.shape + batch, seqlen, nheads, headdim = x.shape + # _, _, ngroups, dstate = B.shape + # assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + # assert C.shape == B.shape + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups) + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out + + +def get_configs(): + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + + @T.prim_func + def main( + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, 
headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + ): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") + dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + D_local = T.alloc_fragment((1), accum_dtype) + x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) + + T.no_set_max_nreg() + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) + T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) + T.copy(cb_shared, cb_local) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + 
(k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) + T.gemm(cb_local, x_shared, acc_o) + + D_local[0] = D[bz] + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) + T.copy(x_residual_shared, x_residual_local) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] += x_residual_local[i, j] * D_local[0] + + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate + + if not args.tune: + kernel = chunk_scan_fwd( + batch, + seq_len, + chunk_size, + groups, + heads, + dim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/tilelang/original/examples/linear_attention/example_mamba_chunk_state.py b/tilelang/original/examples/linear_attention/example_mamba_chunk_state.py new file mode 100644 index 0000000000000000000000000000000000000000..fb766d5e9c9c7d4b7e4de1f1f15f801f84b9de03 --- /dev/null +++ b/tilelang/original/examples/linear_attention/example_mamba_chunk_state.py @@ -0,0 +1,178 @@ +import argparse +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, repeat +import itertools + + +def chunk_state_triton(B, x, dt, dA_cumsum): + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) + + +def ref_program(B, x, dt, dA_cumsum): + 
""" + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) + + +def get_configs(): + iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[4]) +def chunk_state_fwd( + batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 +): + dtype = T.float16 + accum_dtype = T.float32 + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + + @T.prim_func + def main( + B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), + ): + with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by): + x_shared = T.alloc_shared((block_K, block_M), dtype) + x_local = T.alloc_fragment((block_K, block_M), dtype) + xt_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dA_cumsum_shared = T.alloc_shared((block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + scale = T.alloc_fragment((block_K), accum_dtype) + dA_cs_last = T.alloc_fragment((1), accum_dtype) + dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + + loop_range = T.ceildiv(chunk_size, block_K) + + batch_idx = by % batch + chunk_idx = by // batch + m_idx = bx // T.ceildiv(dstate, block_N) + n_idx = bx % T.ceildiv(dstate, block_N) + + T.annotate_layout( + {x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)} + ) + + dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] + T.clear(acc_o) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx 
* block_M : (m_idx + 1) * block_M, + ], + x_shared, + ) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dA_cumsum_shared, dA_cumsum_local) + T.copy(dt_shared, dt_local) + for i in T.Parallel(block_K): + scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] + T.copy(x_shared, x_local) + for i, j in T.Parallel(block_M, block_K): + xt_local[i, j] = x_local[j, i] * scale[j] + T.copy( + B[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N, + ], + B_shared, + ) + T.gemm(xt_local, B_shared, acc_o) + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N], + ) + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + total_flops = 2 * batch * seq_len * heads * dim * dstate + + if not args.tune: + kernel = chunk_state_fwd( + batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/tilelang/original/examples/linear_attention/example_retention_fwd.py b/tilelang/original/examples/linear_attention/example_retention_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f45e383889bd7ef2d93e1a00539e72110465ea43 --- /dev/null +++ b/tilelang/original/examples/linear_attention/example_retention_fwd.py @@ -0,0 +1,107 @@ +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench + +import argparse + + +@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def chunk_retention_fwd_kernel( + B, + S, + H, + DK, + 
DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tl.cdiv(DK, BK) + NV = tl.cdiv(DV, BV) + NT = tl.cdiv(S, chunk_size) + + @T.prim_func + def chunk_retention_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + ): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + log_decay = T.alloc_var(T.float32) + log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + T.clear(h) + + T.use_swizzle(10) + + for i in T.Pipelined(0, NT): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0) + + T.copy(h, h_shared) + T.gemm(q, h_shared, o, clear_accum=True) + for row, col in T.Parallel(chunk_size, BV): + o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col] + T.gemm(s_shared, v, o) + + for row, col in T.Parallel(chunk_size, BV): + v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) + for row, col in T.Parallel(BK, BV): + h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] + T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV]) + T.gemm(k, v, h, transpose_A=True) + + return chunk_retention_fwd + + +def postprocess(o): + return o if o.size(0) == 1 else o.sum(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=4096, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + B, S, H, D = args.B, args.S, args.H, args.D + total_flops = 2.0 * B * S * S * H * D # causal + + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + kernel = chunk_retention_fwd_kernel(B, S, H, D, D) + + t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) + print(f"Tilelang latency: {t:.3f} ms") + print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/linear_attention/test_linear_attn.py 
b/tilelang/original/examples/linear_attention/test_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..346fa8e96ed983bb12f8bb0845611203d536392a --- /dev/null +++ b/tilelang/original/examples/linear_attention/test_linear_attn.py @@ -0,0 +1,18 @@ +import tilelang.testing + +import example_linear_attn_fwd +import example_linear_attn_bwd + + +@tilelang.testing.requires_cuda +def test_example_linear_attn_fwd(): + example_linear_attn_fwd.main() + + +@tilelang.testing.requires_cuda +def test_example_linear_attn_bwd(): + example_linear_attn_bwd.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/minference/README.md b/tilelang/original/examples/minference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8cba732609e59295453a32a846a21fbd73bcb3cd --- /dev/null +++ b/tilelang/original/examples/minference/README.md @@ -0,0 +1,28 @@ +# Performance Benchmark + +## Hardware & Environment +- **Hardware**: NVIDIA H100 PCIe +- **CUDA version**: 12.8.1 +- **PyTorch Version**: 2.7.1+cu128 +- **Triton Version**: 3.3.1 + +## Performance Results +BATCH_SIZE=1, HEAD=1, DIM=64 + +| SEQ_LEN | VS_LIST | Triton Time | TileLang Time | Speedup | +|---------|--------------|-------------|---------------|---------| +| 8192 | [1000, 200] | 0.168 ms | 0.105 ms | 1.60x | +| 8192 | [1000, 600] | 0.207 ms | 0.119 ms | 1.74x | +| 8192 | [800, 600] | 0.207 ms | 0.122 ms | 1.70x | +| | | | | | +| 16384 | [1000, 200] | 0.261 ms | 0.167 ms | 1.56x | +| 16384 | [1000, 600] | 0.419 ms | 0.258 ms | 1.62x | +| 16384 | [800, 600] | 0.422 ms | 0.255 ms | 1.65x | +| | | | | | +| 32768 | [1000, 200] | 0.374 ms | 0.248 ms | 1.51x | +| 32768 | [1000, 600] | 0.823 ms | 0.554 ms | 1.49x | +| 32768 | [800, 600] | 0.826 ms | 0.558 ms | 1.48x | +| | | | | | +| 65536 | [1000, 200] | 0.637 ms | 0.524 ms | 1.22x | +| 65536 | [1000, 600] | 1.758 ms | 1.501 ms | 1.17x | +| 65536 | [800, 600] | 1.783 ms | 1.489 ms | 1.20x | diff --git a/tilelang/original/examples/minference/example_vertical_slash_sparse_attn.py b/tilelang/original/examples/minference/example_vertical_slash_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f96e73ae511d300e3fa3569ef6910805ea19bca6 --- /dev/null +++ b/tilelang/original/examples/minference/example_vertical_slash_sparse_attn.py @@ -0,0 +1,623 @@ +# Copyright (c) 2024-2025 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math +import argparse + +import torch +import triton +import triton.language as tl + +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench + + +@tilelang.jit(out_idx=[3]) +def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): + block_M = 64 + block_N = 64 + num_stages = 2 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, heads, seq_len, dim] + + seq_blocks = (seq_len + block_M - 1) // block_M + + count_shape = [batch, heads, seq_blocks] + + offset_shape = count_shape + [slash_size] + index_shape = count_shape + [vertical_size] + + vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) + + dtype = T.float16 + accum_dtype = T.float32 + int_dtype = T.int32 + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def Prefetch( + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, 
dim], dtype), + column_index: T.SharedBuffer([vertical_size_round], int_dtype), + column_count: T.int32, + k: T.int32, + bz: T.int32, + by: T.int32, + ): + with T.attr("default", "async_scope", 1): + for i, j in T.Parallel(block_N, dim): + K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0) + + with T.attr("default", "async_scope", 1): + for i, j in T.Parallel(block_N, dim): + V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0) + + T.ptx_commit_group() + + @T.macro + def Compute( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + k: T.int32, + column_count: T.int32, + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + count: T.int32, + ): + T.ptx_wait_group(count) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k + j < column_count, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = acc_o[i, j] * scores_scale[i] + + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + @T.prim_func + def vs_sparse_flashattn_ws( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + BlockCount: T.Tensor(count_shape, int_dtype), + BlockOffset: T.Tensor(offset_shape, int_dtype), + ColumnCount: T.Tensor(count_shape, int_dtype), + ColumnIndex: T.Tensor(index_shape, int_dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): + bx = T.ceildiv(seq_len, block_M) - 1 - bc + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([2, block_N, dim], dtype) + V_shared = T.alloc_shared([2, block_N, dim], dtype) + K_shared_1 = T.alloc_shared([block_N, dim], dtype) + V_shared_1 = T.alloc_shared([block_N, dim], dtype) + K_shared_2 = T.alloc_shared([block_N, dim], dtype) + V_shared_2 = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = 
T.alloc_fragment([block_M], accum_dtype) + block_count = T.alloc_local([1], int_dtype) + block_offset = T.alloc_shared([slash_size_round], int_dtype, scope="shared") + column_count = T.alloc_local([1], int_dtype) + column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared") + + T.create_list_of_mbarrier([128] * 9) + + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + block_count[0] = BlockCount[bz, by, bx] + column_count[0] = ColumnCount[bz, by, bx] + + for vi in T.Parallel(slash_size_round): + if vi < slash_size: + block_offset[vi] = BlockOffset[bz, by, bx, vi] + + for vi in T.Parallel(vertical_size_round): + if vi < vertical_size: + column_index[vi] = ColumnIndex[bz, by, bx, vi] + + tid = T.get_thread_binding() + + if tid >= 128: + T.annotate_producer_reg_dealloc() + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.mbarrier_arrive(mbarrier=8) + for bi in T.serial(block_count[0]): + k = block_offset[bi] + T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) + T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) + T.mbarrier_arrive(mbarrier=bi % 2) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) + T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) + T.mbarrier_arrive(mbarrier=bi % 2 + 2) + else: + T.annotate_consumer_reg_alloc() + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.mbarrier_wait_parity(mbarrier=8, parity=0) + for bi in T.serial(block_count[0]): + k = block_offset[bi] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) + + T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) + T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.mbarrier_arrive(mbarrier=bi % 2 + 4) + + T.copy(scores_max, scores_max_prev) + + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = acc_o[i, j] * scores_scale[i] + + T.copy(acc_s, acc_s_cast) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1)) + T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.mbarrier_arrive(mbarrier=bi % 2 + 6) + + T.reduce_sum(acc_s, scores_sum, dim=1) + + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + if column_count[0] != 0: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by) + for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): + k = bi * block_N + if bi % 2 == 0: + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 1, + ) + else: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, 
+ scores_sum, + logsum, + 1, + ) + if T.ceildiv(column_count[0], block_N) % 2 == 0: + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 0, + ) + else: + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 0, + ) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return vs_sparse_flashattn_ws + + return kernel_func(block_M, block_N, num_stages, threads) + + +@triton.jit +def _triton_mixed_sparse_attn_fwd_kernel( + Q, + K, + V, + seqlens, + sm_scale, + block_count, + block_offset, + column_count, + column_index, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_ok, + Z, + H, + N_CTX, + NUM_ROWS, + NNZ_S, + NNZ_V, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, +): + start_m = tl.program_id(0) # bx + off_hz = tl.program_id(1) # by + + seqlen = tl.load(seqlens + off_hz // H) + if start_m * BLOCK_M >= seqlen: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh + kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok + + num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m) + blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S + num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m) + cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(dtype) + + # loop over k, v and update accumulator + m_mask = offs_m[:, None] < seqlen + + for block_index in range(num_blks): + start_n = tl.load(blks_ptr + block_index) + cols = start_n + offs_n + n_mask = cols < seqlen + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) + v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + causal_mask = cols[None, :] <= offs_m[:, None] + qk = tl.where(m_mask & causal_mask, qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - 
m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(dtype), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + for start_n in range(0, num_cols, BLOCK_N): # + # bi * BLOCK_N: bi * BLOCK_N + BLOCK_N + n_mask = start_n + offs_n < num_cols + cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0) + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) + v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(m_mask & n_mask, qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(dtype), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O + acc /= l_i[:, None] + # acc = tl.where(m_mask, acc / l_i[:, None], 0.0) + tl.store(o_ptrs, acc.to(dtype), mask=m_mask) + + +def _triton_mixed_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlens: torch.Tensor, + block_count: torch.Tensor, + block_offset: torch.Tensor, + column_count: torch.Tensor, + column_index: torch.Tensor, + sm_scale: float, + block_size_M: int = 64, + block_size_N: int = 64, +) -> torch.Tensor: + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.zeros_like(q) + grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1) + dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16 + _triton_mixed_sparse_attn_fwd_kernel[grid]( + q, + k, + v, + seqlens, + sm_scale, + block_count, + block_offset, + column_count, + column_index, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + block_count.shape[-1], + block_offset.shape[-1], + column_index.shape[-1], + BLOCK_M=block_size_M, + BLOCK_N=block_size_N, + BLOCK_DMODEL=Lk, + dtype=dtype, + num_warps=4, + num_stages=2, + ) + + return o + + +def vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + block_size_M: int = 64, + block_size_N: int = 64, +): + from torch.utils.cpp_extension import load + import os + + current_dir = os.path.dirname(os.path.abspath(__file__)) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) + convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes + batch_size, num_heads, context_size, head_dim = query.shape + pad = (block_size_M - context_size) & (block_size_M - 1) + if pad == block_size_M: + pad = 0 + query = 
torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim + query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + + seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) + sm_scale = head_dim**-0.5 + block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( + seqlens, + v_idx, + s_idx, + context_size, + block_size_M, + block_size_N, + ) + + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2]) + + def run(is_triton: bool = True): + if is_triton: + out = _triton_mixed_sparse_attention( + query, + key, + value, + seqlens, + block_count, + block_offset, + column_count, + column_index, + sm_scale, + block_size_M, + block_size_N, + ) + else: + out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) + return out[..., :context_size, :head_dim] + + return run + + +def sum_all_diagonal_matrix(mat: torch.tensor): + b, h, n, m = mat.shape + zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right + mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides + sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns + return sum_diags[:, :, 1:] + + +def main(argv=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--vertical_size", type=int, default=1000) + parser.add_argument("--slash_size", type=int, default=200) + + args = parser.parse_args(argv) + + BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim + + vertical_size, slash_size = args.vertical_size, args.slash_size + + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + q_len = SEQ_LEN + + vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) + last_q = 64 + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) + arange = torch.arange(last_q, device="cuda") + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) + qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + vertical_topk = torch.topk(vertical, vertical_size, -1).indices + + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 
1] + slash[..., -30:] = torch.inf + + slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices + + _attn = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash) + + tilelang_out = _attn(False) + triton_out = _attn(True) + + torch.testing.assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2) + + triton_time = do_bench(lambda: _attn(True)) + tilelang_time = do_bench(lambda: _attn(False)) + + print(f"triton_time: {triton_time:.3f}ms") + print(f"tilelang_time: {tilelang_time:.3f}ms") + print(f"speedup: {triton_time / tilelang_time:.2f}x") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/minference/ops/kernels.cpp b/tilelang/original/examples/minference/ops/kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f1e33976447b6dbc8ae1f591af2dc27851f0e0d --- /dev/null +++ b/tilelang/original/examples/minference/ops/kernels.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "torch/extension.h" +#include + +std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, int block_size_M, int block_size_N); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes, + "dynamic sparse index function"); +} diff --git a/tilelang/original/examples/minference/ops/vertical_slash_index.cu b/tilelang/original/examples/minference/ops/vertical_slash_index.cu new file mode 100644 index 0000000000000000000000000000000000000000..ae01f331b1ab284e6d646aa072e7ab61bb5b3d0a --- /dev/null +++ b/tilelang/original/examples/minference/ops/vertical_slash_index.cu @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
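+//
+// What this file does: for every (batch, head, BLOCK_SIZE_M-row query block) it
+// converts the per-head vertical (column) indices and slash (diagonal) indices
+// into block-sparse metadata:
+//   block_count / block_offset  - dense key blocks to visit, derived from the
+//                                 slash diagonals crossing this query block;
+//   column_count / column_index - individual key columns, i.e. vertical indices
+//                                 not already covered by one of those blocks.
+// These four tensors are exactly what the Triton and TileLang kernels in
+// example_vertical_slash_sparse_attn.py consume.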
+ +#include + +#include +#include +#include +#include + +#include + +__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, + int N_ROWS, + int BLOCK_SIZE_M, + int BLOCK_SIZE_N, + int NNZ_V, + int NNZ_S +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, + int N_HEADS, + int N_ROWS, + int NNZ_V, + int NNZ_S +) { + const int BLOCK_SIZE_M = 64; + const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + seqlens, vertical_indexes, slash_indexes, + block_count, block_offset, column_count, column_index, + N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S + ); +} + 
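+// Launch configuration: each thread owns one BLOCK_SIZE_M = 64 row block of
+// queries, a thread block holds 64 such threads, and the grid is laid out as
+// (N_HEADS, BATCH_SIZE, cdiv(N_ROWS, 64)). Inside the kernel, blockIdx.x / .y / .z
+// are therefore read back as head, batch and row-block group respectively.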
+std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, + int block_size_M, + int block_size_N +) { + assert(block_size_M == 64); + assert(block_size_N == 64); + + cudaSetDevice(seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + nnz_vertical, + nnz_slash + ); + + return { block_count, block_offset, column_count, column_index }; +} diff --git a/tilelang/original/examples/minference/ops/vertical_slash_index.hip b/tilelang/original/examples/minference/ops/vertical_slash_index.hip new file mode 100644 index 0000000000000000000000000000000000000000..f01fd421125d6ccde89bb402c2cd9a30cb1cec20 --- /dev/null +++ b/tilelang/original/examples/minference/ops/vertical_slash_index.hip @@ -0,0 +1,161 @@ +// !!! This is a file automatically generated by hipify!!! +#include +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
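+//
+// Note: this file is the hipify-generated ROCm port of vertical_slash_index.cu.
+// The index-conversion logic is unchanged; only the runtime calls differ
+// (cudaSetDevice -> hipSetDevice, the <<<...>>> launch -> hipLaunchKernelGGL).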
+ +#include + +#include +#include +#include +#include + +#include + +__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, + int N_ROWS, + int BLOCK_SIZE_M, + int BLOCK_SIZE_N, + int NNZ_V, + int NNZ_S +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, + int N_HEADS, + int N_ROWS, + int NNZ_V, + int NNZ_S +) { + const int BLOCK_SIZE_M = 64; + const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0, + seqlens, vertical_indexes, slash_indexes, + block_count, block_offset, column_count, column_index, + N_HEADS, 
N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S + ); +} + +std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, + int block_size_M, + int block_size_N +) { + assert(block_size_M == 64); + assert(block_size_N == 64); + + hipSetDevice(seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + nnz_vertical, + nnz_slash + ); + + return { block_count, block_offset, column_count, column_index }; +} diff --git a/tilelang/original/examples/minference/test_vs_sparse_attn.py b/tilelang/original/examples/minference/test_vs_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f01df3808f86fb17d41a7b3617f76c949534fa45 --- /dev/null +++ b/tilelang/original/examples/minference/test_vs_sparse_attn.py @@ -0,0 +1,12 @@ +import tilelang.testing + +import example_vertical_slash_sparse_attn + + +@tilelang.testing.requires_cuda +def test_vs_sparse_attn(): + example_vertical_slash_sparse_attn.main(argv=[]) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/norm/rms_norm.py b/tilelang/original/examples/norm/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..57bccc1a0f901ddb2ac84b0b0e2ac8f92c95480a --- /dev/null +++ b/tilelang/original/examples/norm/rms_norm.py @@ -0,0 +1,76 @@ +import torch +import tilelang +import tilelang.language as T + + +def rms_norm_splitk(M, N, blk_m, blk_k): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, blk_k), dtype) + A_local = T.alloc_fragment((blk_m, blk_k), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + num_k_step = T.ceildiv(N, blk_k) + T.clear(A_local) + for k in range(num_k_step): + T.copy(A[bx * blk_m, k * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_local[i, j] += A_shared[i, j] * A_shared[i, j] + T.reduce_sum(A_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + + for k in range(num_k_step): + # reverse, better cache hit rate + T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_shared[i, j] *= A_powsum[i] + T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) + + return main + + +@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) +def rms_norm(M, N, blk_m): + dtype = T.float + + @T.prim_func + def main(A: 
T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, N), dtype) + A_pow_local = T.alloc_fragment((blk_m, N), dtype) + A_local = T.alloc_fragment((blk_m, N), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(blk_m, N): + A_pow_local[i, j] = A_local[i, j] * A_local[i, j] + T.reduce_sum(A_pow_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + for i, j in T.Parallel(blk_m, N): + A_local[i, j] *= A_powsum[i] + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + + +def ref_program(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) + + +if __name__ == "__main__": + M, N, blk_m, blk_k = 8192, 8192, 1, 512 + kernel = rms_norm(M, N, blk_m) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) diff --git a/tilelang/original/examples/norm/test_rms_norm.py b/tilelang/original/examples/norm/test_rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..53db03d98ce32ee034a1d36b9ee590c0992267ce --- /dev/null +++ b/tilelang/original/examples/norm/test_rms_norm.py @@ -0,0 +1,74 @@ +import torch +import tilelang +import tilelang.testing +import tilelang.language as T + + +def rms_norm_splitk(M, N, blk_m, blk_k): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, blk_k), dtype) + A_local = T.alloc_fragment((blk_m, blk_k), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + num_k_step = T.ceildiv(N, blk_k) + T.clear(A_local) + for k in range(num_k_step): + T.copy(A[bx * blk_m, k * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_local[i, j] += A_shared[i, j] * A_shared[i, j] + T.reduce_sum(A_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + + for k in range(num_k_step): + # reverse, better cache hit rate + T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_shared[i, j] *= A_powsum[i] + T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) + + return main + + +def rms_norm(M, N, blk_m): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, N), dtype) + A_pow_local = T.alloc_fragment((blk_m, N), dtype) + A_local = T.alloc_fragment((blk_m, N), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(blk_m, N): + A_pow_local[i, j] = A_local[i, j] * A_local[i, j] + T.reduce_sum(A_pow_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + for i, j in T.Parallel(blk_m, N): + A_local[i, j] *= A_powsum[i] + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + + +def ref_program(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) + + +def 
test_rms_norm(M=1024, N=1024, blk_m=1): + program = rms_norm(M, N, blk_m) + kernel = tilelang.compile(program, out_idx=-1, pass_configs={"tl.disable_tma_lower": True}) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/online_softmax/online_softmax.py b/tilelang/original/examples/online_softmax/online_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..811870e441b3f297c711a04179bde5cce925f126 --- /dev/null +++ b/tilelang/original/examples/online_softmax/online_softmax.py @@ -0,0 +1,72 @@ +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench +from typing import Callable + + +@tl.jit(out_idx=[1]) +def softmax_kernel( + M, + N, + dtype: T.dtype = T.float16, +) -> "Callable": + BN = min(tl.next_power_of_2(N), 8192) + NN = tl.cdiv(N, BN) + + accum_dtype = T.float32 + + scale = 1.44269504 # log2(e) + + @T.prim_func + def main( + X: T.Tensor([M, N], dtype), + Y: T.Tensor([M, N], dtype), + ): + with T.Kernel(M, threads=128) as (i_m): + x = T.alloc_fragment([BN], dtype) + y = T.alloc_fragment([BN], dtype) + lse = T.alloc_fragment([1], accum_dtype) + max_x = T.alloc_fragment([1], dtype) + exp_x = T.alloc_fragment([BN], accum_dtype) + sum_exp_x = T.alloc_fragment([1], accum_dtype) + T.fill(lse, -T.infinity(accum_dtype)) + + for i_n in T.Pipelined(0, NN): + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) + + T.reduce_max(x, max_x, dim=0, clear=True) + + for j in T.Parallel(BN): + exp_x[j] = T.exp2(x[j] * scale - max_x[0] * scale) + + T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True) + + lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) + + for i_n in T.Pipelined(0, NN): + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) + + for j in T.Parallel(BN): + y[j] = T.exp2(x[j] * scale - lse[0]) + + T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN]) + + return main + + +M = 8192 +N = 8192 +kernel = softmax_kernel(M, N) +dtype = torch.float16 +X = torch.randn(M, N, dtype=dtype, device="cuda") +Y = kernel(X) +Y_ref = X.softmax(dim=1) + +torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2) + +t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100) +t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) +print(f"torch latency: {t1:.3f} ms") +print(f"TileLang latency: {t2:.3f} ms") +print(f"Speedup: {t1 / t2:.3f}x") diff --git a/tilelang/original/examples/plot_layout/README.md b/tilelang/original/examples/plot_layout/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8204e93d804edde4a1d9bbb00366f7c7be39dae1 --- /dev/null +++ b/tilelang/original/examples/plot_layout/README.md @@ -0,0 +1,108 @@ +The following example demonstrates how to generate and visualize a memory layout using tilelang tools `plot_layout`. + +Example Code + +```python +import tilelang.language as T +from tvm import DataType +from tvm.tir import IndexMap +from typing import Literal, Callable +from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.tools import plot_layout + +def make_mma_load_base_layout(dtype: str = T.float16, + matrix: Literal["A", "B"] = "A", + transposed: bool = False) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. 
+ + Parameters + ---------- + dtype : str + The data type of the matrix. + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.intrinsics.mma_layout import ( + shared_16x16_to_mma_32x8_layout_sr, + shared_16x16_to_mma_32x8_layout_rs, + shared_16x32_to_mma_32x16_layout, + shared_32x16_to_mma_32x16_layout, + ) + assert matrix in ["A", "B"], "matrix should be either A or B" + dtype_bits = DataType(dtype).bits + assert transposed is False, "transposed is not supported yet" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr: Callable = None + transform_func_rs: Callable = None + if dtype_bits == 16: + transform_func_sr = shared_16x16_to_mma_32x8_layout_sr + transform_func_rs = shared_16x16_to_mma_32x8_layout_rs + elif dtype_bits == 8: + transform_func_sr = shared_16x32_to_mma_32x16_layout + transform_func_rs = shared_32x16_to_mma_32x16_layout + else: + raise ValueError(f"Unsupported dtype {dtype}") + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs + + micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) + + transform_func = transform_func + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +# Create a 16×16 matrix layout for ldmatrix operations +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) + +# Print the layout structure (optional for debugging) +print(base_layout) + +# Plot and save the layout visualization +plot_layout(base_layout, name="base_layout") +``` + +Output + +![base_layout](./images/base_layout.png) diff --git a/tilelang/original/examples/plot_layout/fragment_mfma_load_a.py b/tilelang/original/examples/plot_layout/fragment_mfma_load_a.py new file mode 100644 index 0000000000000000000000000000000000000000..d45cc227bc2d0fcef5f1d034c0ed51f62f4c571e --- /dev/null +++ b/tilelang/original/examples/plot_layout/fragment_mfma_load_a.py @@ -0,0 +1,127 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + +from tilelang.intrinsics.mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_16x16_to_local_64x4_layout_A, + shared_16x32_to_local_64x8_layout_A, + shared_16x64_to_local_64x16_layout_A, +) + + +def make_mfma_load_base_layout( + dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", k_dim: int = 16, 
transposed: bool = False +) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mfma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mfma operand to be loaded. + k_dim : int + The k dimension of the mfma. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. + + """ + + assert matrix in ["A", "B"], "matrix should be either A or B" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. 
+ # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 2 +warp_cols = 2 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mfma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x32 +warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 64x32 +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/tilelang/original/examples/plot_layout/fragment_mma_load_a.py b/tilelang/original/examples/plot_layout/fragment_mma_load_a.py new file mode 100644 index 0000000000000000000000000000000000000000..df4a0b88701192c44e9743360ad7ead14d4f0dbd --- /dev/null +++ b/tilelang/original/examples/plot_layout/fragment_mma_load_a.py @@ -0,0 +1,122 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm import DataType +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + + +def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mma operand to be loaded. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. 
+ + """ + from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, + shared_16x8_to_mma_32x4_layout_sr_b, + shared_16x16_to_mma_32x8_layout_sr_b, + shared_16x32_to_mma_32x16_layout_sr_b, + ) + + assert matrix in ["A", "B"], "matrix should be either A or B" + dtype_bits = DataType(dtype).bits + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 4 +warp_cols = 4 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x16 +warp_layout = base_layout.repeat([block_rows, 1], repeat_on_thread=True).replicate(block_cols) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 128x32 +block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/tilelang/original/examples/plot_layout/images/base_layout.png b/tilelang/original/examples/plot_layout/images/base_layout.png new file mode 100644 index 
0000000000000000000000000000000000000000..e8ebcf8b6971170b7dc2dfd5e66168bb487b7794 Binary files /dev/null and b/tilelang/original/examples/plot_layout/images/base_layout.png differ diff --git a/tilelang/original/examples/pytest.ini b/tilelang/original/examples/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..5f820048e6dfb0c195518233427628c5f4da027a --- /dev/null +++ b/tilelang/original/examples/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +norecursedirs = bitnet-1.58b diff --git a/tilelang/original/examples/quickstart.py b/tilelang/original/examples/quickstart.py new file mode 100644 index 0000000000000000000000000000000000000000..e99fc0dbceff115a0569495b563764170f05fa89 --- /dev/null +++ b/tilelang/original/examples/quickstart.py @@ -0,0 +1,87 @@ +import tilelang +import tilelang.language as T + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 1024 # M = T.dynamic("m") if you want to use dynamic shape +N = 1024 +K = 1024 +block_M = 128 +block_N = 128 +block_K = 32 + +# Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) +# Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. 
Retrieve and inspect the generated CUDA source (optional) +# cuda_source = matmul_relu_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/tilelang/original/examples/rand/rand_uint.py b/tilelang/original/examples/rand/rand_uint.py new file mode 100644 index 0000000000000000000000000000000000000000..466a51b7a312643c66d02c7069db021f6f3cb036 --- /dev/null +++ b/tilelang/original/examples/rand/rand_uint.py @@ -0,0 +1,57 @@ +import tilelang +import tilelang.language as T +import torch +import triton +import triton.language as tl + + +@tilelang.jit +def tilelang_rand_1d(M=1024, seed=42): + num_per_thread = 128 + threads = 1 + blk_M = num_per_thread * threads + + @T.prim_func + def rand_kernel(A: T.Tensor((M,), "uint32")): + with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx: + tx = T.get_thread_binding() + T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread) + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + A[idx] = T.rng_rand() + + return rand_kernel + + +@triton.jit +def triton_rand_1d(X, M, elements_per_thread, seed): + pid = tl.program_id(0) + offset = pid * elements_per_thread + tl.arange(0, elements_per_thread) + + r0, r1, r2, r3 = tl.randint4x(seed, offset) + + base_idx = offset * 4 + tl.store(X + base_idx, r0, mask=base_idx < M) + tl.store(X + base_idx + 1, r1, mask=(base_idx + 1) < M) + tl.store(X + base_idx + 2, r2, mask=(base_idx + 2) < M) + tl.store(X + base_idx + 3, r3, mask=(base_idx + 3) < M) + + +def test_rand_1d(M, seed): + kernel = tilelang_rand_1d(M, seed) + tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda") + kernel(tilelang_result) + + triton_result = torch.empty(M, dtype=torch.uint32, device="cuda") + grid = (triton.cdiv(M, 128),) + triton_rand_1d[grid](triton_result, tl.constexpr(M), tl.constexpr(128 // 4), seed) + + torch.testing.assert_close(tilelang_result, triton_result) + + +if __name__ == "__main__": + test_rand_1d(1024, 42) + test_rand_1d(512, 123) + test_rand_1d(128, 0) diff --git a/tilelang/original/examples/seer_attention/block_sparse_attn_tilelang.py b/tilelang/original/examples/seer_attention/block_sparse_attn_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..25741f97cce73d6e8d9c06c9928590db5a53d68c --- /dev/null +++ b/tilelang/original/examples/seer_attention/block_sparse_attn_tilelang.py @@ -0,0 +1,254 @@ +import math +import torch + +import tilelang +import tilelang.language as T +import torch.nn.functional as F + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.int8 + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
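+                # `scale` already folds in both the softmax scaling and the base
+                # conversion: scale = (1/dim)**0.5 * log2(e), so the line below
+                # evaluates exp((x - row_max) / sqrt(dim)) using exp2 only.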
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = T.ceildiv(seq_kv, block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + past_len = seq_kv - seq_q + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = 
get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + # Run tilelang kernel + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlen_lt_klen(): + # Config + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + print(kernel.get_kernel_source()) + tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy. 
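+    # fp16 inputs with fp32 accumulation inside the kernel: the 1e-2 absolute and
+    # relative tolerances match the qlen == klen test above.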
+ torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2) + + print("Pass topk sparse attention test with qlen < klen") + + +def main(): + test_topk_sparse_attention() + test_topk_sparse_attention_qlen_lt_klen() + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/seer_attention/block_sparse_attn_triton.py b/tilelang/original/examples/seer_attention/block_sparse_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cc3cd00c854cdde0757af38dcf6302976cd49f --- /dev/null +++ b/tilelang/original/examples/seer_attention/block_sparse_attn_triton.py @@ -0,0 +1,347 @@ +# ruff: noqa: E712 +import math +import torch + +import triton +import triton.language as tl +import torch.nn.functional as F + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + past_len, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = 
tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + PAST_LEN, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + print("PAST_LEN", PAST_LEN) + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. 
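+        # Forward-only example: the sparse kernel is used here for inference
+        # benchmarking and correctness checks, so no backward pass is provided.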
+ raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + print("downsample_len", downsample_len) + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + print("x_ds.shape", x_ds.shape) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + + # Run Triton kernel + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # print("ref_output", ref_output) + # print("triton_output", triton_output) + + # Verify accuracy + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlt_kl(): + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 64, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + # softmax scale + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # Run Triton kernel. 
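+    # block_sparse_triton_fn is _sparse_attention.apply; _forward launches the
+    # kernel with BLOCK_M = BLOCK_N = 64, matching the downsample block size above.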
+ triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # Verify accuracy. + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" + + print("Pass topk sparse attention test with qlen < klen") + + +if __name__ == "__main__": + test_topk_sparse_attention() + test_topk_sparse_attention_qlt_kl() diff --git a/tilelang/original/examples/seer_attention/test_block_sparse_attn_tilelang.py b/tilelang/original/examples/seer_attention/test_block_sparse_attn_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..da175d05c7f71b66deb71b6785506fd98d85be54 --- /dev/null +++ b/tilelang/original/examples/seer_attention/test_block_sparse_attn_tilelang.py @@ -0,0 +1,12 @@ +import tilelang.testing + +import block_sparse_attn_tilelang + + +@tilelang.testing.requires_cuda +def test_block_sparse_attn_tilelang(): + block_sparse_attn_tilelang.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/sparse_tensorcore/test_example_sparse_tensorcore.py b/tilelang/original/examples/sparse_tensorcore/test_example_sparse_tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..72292e44868dc30c7f2b6b5044a4449c5e9f559e --- /dev/null +++ b/tilelang/original/examples/sparse_tensorcore/test_example_sparse_tensorcore.py @@ -0,0 +1,13 @@ +import tilelang.testing +import tilelang +import tilelang_example_sparse_tensorcore + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tilelang_example_sparse_tensorcore(): + tilelang_example_sparse_tensorcore.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/tilelang/original/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..14339ff02932819d4273acc818e8da0256354dcc --- /dev/null +++ b/tilelang/original/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -0,0 +1,117 @@ +import torch +import tilelang +from tilelang.utils.sparse import compress_sm90 +from tilelang.layout import make_cutlass_metadata_layout +from tilelang import language as T +import tilelang.testing + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) + B_shape = (K, N) + A_shared_shape = (block_M, block_K // 2) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), 
"uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // 8), "uint8") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K), + } + ) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // 8], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"): + if shape[-1] % 4 != 0: + raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") + + full_tensor = torch.randn(shape, dtype=dtype, device=device) + group_count = shape[-1] // 4 + group_shape = shape[:-1] + (group_count, 4) + + rand_vals = torch.rand(group_shape, device=device) + topk_indices = rand_vals.topk(k=2, dim=-1).indices + mask = torch.zeros(group_shape, dtype=torch.bool, device=device) + mask.scatter_(-1, topk_indices, True) + mask = mask.view(shape) + + sparse_tensor = full_tensor * mask + return sparse_tensor + + +def run_gemm_sp( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, +): + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) + + C_sp = kernel(A_sparse, E, B).half() + C = torch.matmul(A, B) + torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) + print("pass") + + +def main(): + run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/topk/example_topk.py b/tilelang/original/examples/topk/example_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f0c8bfb28a116ca418633db3e0450d75bbf55e --- /dev/null +++ b/tilelang/original/examples/topk/example_topk.py @@ -0,0 +1,93 @@ +import tilelang +import tilelang.language as T +import torch +import itertools +import argparse + + +def get_configs(): + iter_params = dict( + blk_m=[64, 128, 256], + threads=[128, 256, 512], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[1, 2]) +def tl_topk( + M, + N, + topk, + blk_m, + threads=128, +): + dtype = T.float32 + + @T.prim_func + def topk_kernel( + logits: T.Tensor([M, N], dtype), + topk_gates: T.Tensor([M, topk], dtype), + topk_indices: T.Tensor([M, topk], T.int32), + ): + with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx: + logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) + max_val = T.alloc_fragment([blk_m], dtype=dtype) + expand_max_idx = 
T.alloc_fragment([blk_m, N], T.int32) + max_idx = T.alloc_fragment([blk_m], T.int32) + + T.copy(logits[bx * blk_m, 0], logits_frag) + + for k in T.serial(topk): + T.fill(expand_max_idx, -1) + T.reduce_max(logits_frag, max_val, dim=1, clear=True) + + for i, j in T.Parallel(blk_m, N): + expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j]) + + T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True) + + for i, j in T.Parallel(blk_m, N): + logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j]) + + for i in T.Parallel(blk_m): + topk_gates[bx * blk_m + i, k] = max_val[i] + topk_indices[bx * blk_m + i, k] = max_idx[i] + + return topk_kernel + + +def ref_program(logits, top_k): + top_k_gates, top_k_indices = logits.topk(top_k, dim=1) + + return top_k_gates, top_k_indices.to(torch.int32) + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=320, help="num_tokens") + parser.add_argument("--N", type=int, default=128, help="num_experts") + parser.add_argument("--topk", type=int, default=6, help="topk") + parser.add_argument("--blk_m", type=int, default=64, help="blk_m") + args = parser.parse_args(argv) + M, N, topk, blk_m = args.M, args.N, args.topk, args.blk_m + + logits = torch.rand((M, N), device="cuda", dtype=torch.float32) + + kernel = tl_topk(M=M, N=N, topk=topk, blk_m=blk_m) + tl_gates, tl_indices = kernel(logits) + + torch_gates, torch_indices = ref_program(logits, topk) + + # test accuracy + torch.testing.assert_close(tl_gates, torch_gates) + torch.testing.assert_close(tl_indices, torch_indices) + + # profile + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + print(f"Tilelang latency: {tilelang_latency}") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/topk/test_topk_tilelang.py b/tilelang/original/examples/topk/test_topk_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..54de01143ffa23496ac65233155ce0f68bc28b5c --- /dev/null +++ b/tilelang/original/examples/topk/test_topk_tilelang.py @@ -0,0 +1,11 @@ +import tilelang.testing +import example_topk + + +@tilelang.testing.requires_cuda +def test_topk_tilelang(): + example_topk.main(argv=[]) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/examples/visual_layout_inference/visual_layout_inference.py b/tilelang/original/examples/visual_layout_inference/visual_layout_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa1eaf854ec283042b2f3a1c7d2c9d1ae1dd457 --- /dev/null +++ b/tilelang/original/examples/visual_layout_inference/visual_layout_inference.py @@ -0,0 +1,61 @@ +import tilelang +import tilelang.language as T + + +# use pass_configs to enable layout visualization +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", + }, +) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), 
accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(128, 128, 128, 32, 32, 32) + + import torch + + a = torch.randn(128, 128).cuda().half() + b = torch.randn(128, 128).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # print the layout visualization result and save figures to ./tmp. + """ + C_local inferenced layout: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] + """ + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_flashmla.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_flashmla.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcd51aa7c9b885895f19aebe5bbc50f4687f14d --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -0,0 +1,372 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit(out_idx=[6]) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + h_dim = dim // 2 + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + # smem_sQ + Q_shared_l = T.alloc_shared([block_H, h_dim], dtype) + Q_shared_r = T.alloc_shared([block_H, h_dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype) + Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype) + + # smem_sK0 + KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype) + KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype) + K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + + # smem_sK1 + KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype) + KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype) + K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + + # smem_sP0 + SP0_shared = T.alloc_shared([block_H, block_N], dtype) + + # smem_sP1 reuse Q_pe_shared + SP1_shared = Q_pe_shared + + # smem_sM + scores_max = T.alloc_shared([block_H], accum_dtype) + + # smem_sScale0 + scores_scale_0 = T.alloc_shared([block_H], accum_dtype) + # smem_sScale1 + scores_scale_1 = T.alloc_shared([block_H], accum_dtype) + + logsum = T.alloc_shared([block_H], accum_dtype) + + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_s_1 = 
T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype) + scores_max_0 = T.alloc_fragment([block_H], accum_dtype) + scores_max_1 = T.alloc_fragment([block_H], accum_dtype) + + scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype) + + scores_sum_0 = T.alloc_fragment([block_H], accum_dtype) + scores_sum_1 = T.alloc_fragment([block_H], accum_dtype) + logsum_0 = T.alloc_fragment([block_H], accum_dtype) + logsum_1 = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + + T.annotate_layout( + { + O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), + O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), + } + ) + + # barriers_Q + q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) + + # barriers_K0 + kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128) + # barriers_K1 + kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128) + + # redundant barriers + score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128) + scale_1_ready_barrier = T.alloc_barrier(arrive_count=128) + p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128) + lse_0_ready_barrier = T.alloc_barrier(arrive_count=128) + lse_1_ready_barrier = T.alloc_barrier(arrive_count=128) + s_shared_ready_barrier = T.alloc_barrier(arrive_count=128) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.barrier_arrive(q_shared_ready_barrier) + T.barrier_wait(q_shared_ready_barrier, 0) + + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, (block_N * 2)) + + if tx < 128: + T.copy(Q_pe_shared, Q_pe_local_0) + T.fill(acc_o_l, 0) + T.fill(logsum_0, 0) + + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.barrier_arrive(kv_shared_1_l_is_ready) + + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.barrier_arrive(kv_shared_1_r_is_ready) + + T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1) + T.barrier_arrive(kv_shared_1_pe_is_ready) + + for k in T.serial(loop_range): + T.barrier_wait(kv_shared_0_l_is_ready, k % 2) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1) + T.barrier_wait(kv_shared_0_r_is_ready, k % 2) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) + + T.barrier_wait(kv_shared_0_pe_is_ready, k % 2) + T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + # Step 3. + T.copy(scores_max, scores_max_0) + T.copy(scores_max_0, scores_max_prev_0) + T.fill(scores_max_0, -T.infinity(accum_dtype)) + T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False) + T.copy(scores_max_0, scores_max) + + # Step 4. 
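+                    # Online-softmax update: exponentiate the scores against the new
+                    # row maximum and rescale the running state by
+                    # exp2((m_prev - m_new) * scale); base-2 exponentials are valid
+                    # because `scale` already folds in log2(e).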
+ for i, j in T.Parallel(block_H, block_N): + acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale) + for i in T.Parallel(block_H): + scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale) + + T.reduce_sum(acc_s_0, scores_sum_0, dim=1) + + # Step 5. + T.copy(acc_s_0, acc_s_0_cast) + + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] *= scores_scale_0[i] + + for i in T.Parallel(block_H): + logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i] + + # Step 6. + T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l) + T.barrier_arrive(score_max_0_ready_barrier) + + T.barrier_wait(scale_1_ready_barrier, k % 2) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l) + T.barrier_arrive(kv_shared_0_l_is_ready) + + # Step 11. + for i, j in T.Parallel(block_H, block_N): + SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i] + + T.barrier_arrive(p0_1_1_ready_barrier) + + # Step 13. + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] *= scores_scale_1[i] + for i in T.Parallel(block_H): + logsum_0[i] = logsum_0[i] * scores_scale_1[i] + T.barrier_wait(s_shared_ready_barrier, k % 2) + + # Step 14. + T.gemm(SP1_shared, KV_shared_1_l, acc_o_l) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.barrier_arrive(kv_shared_1_l_is_ready) + + T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1) + T.barrier_arrive(kv_shared_1_pe_is_ready) + + T.copy(logsum_0, logsum) + T.barrier_arrive(lse_0_ready_barrier) + T.barrier_wait(lse_1_ready_barrier, 0) + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] /= logsum[i] + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim]) + + else: + T.copy(Q_pe_shared, Q_pe_local_1) + T.fill(acc_o_r, 0) + T.fill(logsum_1, 0) + + T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l) + T.barrier_arrive(kv_shared_0_l_is_ready) + T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r) + T.barrier_arrive(kv_shared_0_r_is_ready) + T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0) + T.barrier_arrive(kv_shared_0_pe_is_ready) + + for k in T.serial(loop_range): + # Step 2. + T.barrier_wait(kv_shared_1_l_is_ready, k % 2) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1) + + T.barrier_wait(kv_shared_1_r_is_ready, k % 2) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) + + T.barrier_wait(kv_shared_1_pe_is_ready, k % 2) + T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + # Step 7. + T.barrier_wait(score_max_0_ready_barrier, k % 2) + + T.copy(scores_max, scores_max_prev_1) + T.fill(scores_max_1, -T.infinity(accum_dtype)) + T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False) + T.copy(scores_max_1, scores_max) + + for i in T.Parallel(block_H): + scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale) + + # Step 8. + for i, j in T.Parallel(block_H, block_N): + acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale) + + # Step 9. 
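+                    # This warp group folds in both rescale factors: scores_scale_0
+                    # from the other group's max update and scores_scale_1 from its
+                    # own, so both halves of the output share one running softmax
+                    # normalizer.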
+ T.reduce_sum(acc_s_1, scores_sum_1, dim=1) + + for i, j in T.Parallel(block_H, h_dim): + acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i]) + + for i in T.Parallel(block_H): + logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i] + + T.barrier_arrive(scale_1_ready_barrier) + + # Step 10. compute O1 with KV_shared_1_rd + T.copy(acc_s_1, acc_s_1_cast) + T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1) + T.copy(acc_s_1_cast, SP1_shared) + T.barrier_arrive(s_shared_ready_barrier) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.barrier_arrive(kv_shared_1_r_is_ready) + + T.barrier_wait(p0_1_1_ready_barrier, k % 2) + # Step 12. + T.gemm(SP0_shared, KV_shared_0_r, acc_o_r) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r) + T.barrier_arrive(kv_shared_0_r_is_ready) + + T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0) + T.barrier_arrive(kv_shared_0_pe_is_ready) + + T.barrier_wait(lse_0_ready_barrier, 0) + for i in T.Parallel(block_H): + logsum[i] += logsum_1[i] + T.barrier_arrive(lse_1_ready_barrier) + for i, j in T.Parallel(block_H, h_dim): + acc_o_r[i, j] /= logsum[i] + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:]) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): + qk_flops = 2 * batch * heads * 
kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + num_split = 1 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2aa00d929d0d91895cc98ea3668955a8883e8f --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -0,0 +1,86 @@ +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + num_stages = 2 + mbarrier_list = [128, 128] * num_stages + + @T.prim_func + def main( + A: T.Tensor[(M, K), dtype], + B: T.Tensor[(K, N), dtype], + C: T.Tensor[(M, N), dtype], + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((num_stages, block_M, block_K), dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # create mbarrier for tma + T.create_list_of_mbarrier(mbarrier_list) + + with T.ws(0): + T.clear(C_local) + + for ko in range(T.ceildiv(K, block_K)): + with T.ws(1): + T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1) + T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :]) + T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :]) + T.mbarrier_arrive(mbarrier=ko % num_stages) + with T.ws(0): + T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) + T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local) + T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + tilelang.disable_cache() + block_M = 128 + block_N = 128 + block_K = 64 + jit_kernel = 
matmul(M, N, K, block_M, block_N, block_K) + + print(jit_kernel.get_kernel_source()) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py new file mode 100644 index 0000000000000000000000000000000000000000..7b22784323ba327d0054a0f46e53bd7e6eb6acdc --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -0,0 +1,78 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(1): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + with T.ws(0): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(1): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(1): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=1024, N=1024, K=1024): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K) + + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. 
Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py new file mode 100644 index 0000000000000000000000000000000000000000..02d88019c7e1793c824e088f54d8cd3b3d871212 --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -0,0 +1,79 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(0): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + with T.ws(1): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(0): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. 
Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5468aa6eace4ca259e885db9bd33d9e6a77b459a --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -0,0 +1,96 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + warp_group_num = 2 + threads = 128 * warp_group_num + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") + B_shared_g0 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared") + B_shared_g1 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared") + + C_local_g0 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype) + C_local_g1 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype) + + with T.ws(1): + T.clear(C_local_g1) + with T.ws(0): + T.clear(C_local_g0) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + with T.ws(1): + T.copy(B[ko * block_K, bx * block_N], B_shared_g1) + T.gemm(A_shared, B_shared_g1, C_local_g1) + with T.ws(0): + T.copy(B[ko * block_K, bx * block_N + block_N // warp_group_num], B_shared_g0) + T.gemm(A_shared, B_shared_g0, C_local_g0) + + with T.ws(1): + T.copy(C_local_g1, C[by * block_M, bx * block_N]) + with T.ws(0): + T.copy(C_local_g0, C[by * block_M, bx * block_N + block_N // warp_group_num]) + + return main + + +def main(): + M = 128 + N = 128 + K = 64 + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + print(jit_kernel.get_kernel_source()) + # 3. Test the kernel in Python with PyTorch data + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + print(c) + + # Reference multiplication using PyTorch + ref_c = a @ b + print(ref_c) + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. 
Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..31d156f327a6ccde2c6d1ac236b4622e60048dfe --- /dev/null +++ b/tilelang/original/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -0,0 +1,82 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor[(M, K), dtype], + B: T.Tensor[(K, N), dtype], + C: T.Tensor[(M, N), dtype], + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # create mbarrier for tma + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(0): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + with T.ws(1): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(0): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + + # 3. Test the kernel in Python with PyTorch data + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. 
Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/examples/warp_specialize/test_example_warp_specialize.py b/tilelang/original/examples/warp_specialize/test_example_warp_specialize.py new file mode 100644 index 0000000000000000000000000000000000000000..dee507790b129d462d614d183dcd322910ff0f5d --- /dev/null +++ b/tilelang/original/examples/warp_specialize/test_example_warp_specialize.py @@ -0,0 +1,42 @@ +import tilelang.testing + +import example_warp_specialize_gemm_barrierpipe_stage2 +import example_warp_specialize_gemm_copy_0_gemm_1 +import example_warp_specialize_gemm_copy_1_gemm_0 +import example_warp_specialize_gemm_softpipe_stage2 + +# TODO: skip for now as non-deterministic on H20 +# CC @cunxiao +# @tilelang.testing.requires_cuda +# @tilelang.testing.requires_cuda_compute_version_eq(9, 0) +# def test_example_warp_specialize_flashmla(): +# import example_warp_specialize_flashmla +# example_warp_specialize_flashmla.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_barrierpipe_stage2(): + example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_copy_0_gemm_1(): + example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_copy_1_gemm_0(): + example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_softpipe_stage2(): + example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/images/MatmulExample.png b/tilelang/original/images/MatmulExample.png new file mode 100644 index 0000000000000000000000000000000000000000..555ae30a75b2486bffb8acf27f72802d2c96ec3d Binary files /dev/null and b/tilelang/original/images/MatmulExample.png differ diff --git a/tilelang/original/images/MatmulExample.svg b/tilelang/original/images/MatmulExample.svg new file mode 100644 index 0000000000000000000000000000000000000000..6e20daf554d6ebf18bb28af827f8822238861cf2 --- /dev/null +++ b/tilelang/original/images/MatmulExample.svg @@ -0,0 +1 @@ +A_shared=T.alloc_shared((block_M,block_K))B_shared=T.alloc_shared((block_K,block_N))C_local=T.alloc_fragment((block_M,block_N),accum_dtype)importtilelang.languageasTdefMatmul(A:T.Buffer,B:T.Buffer,C:T.Buffer):withT.Kernel(ceildiv(N,block_N),ceildiv(M,block_M),threads=128)as(bx,by):T.clear(C_local)forkinT.Pipelined(ceildiv(K,block_K),num_stages=3):T.copy(A[by*block_M,k*block_K],A_shared)T.copy(B[k*block_K,bx*block_N],B_shared)T.gemm(A_shared,B_shared,C_local)Kernel Context InitializationBuffer AllocationRegisterInitialize Accumulate Buffer with ZeroMain Loop with Pipeline AnnotationT.copy(C_local,C[by*block_M,bx*block_N])Write Back to Global MemoryCopy Data from Global to Shared 
MemoryGEMMSharedMemoryGlobal MemoryShared MemoryRegister Files(a) Efficient GEMM with Multi-Level Tiling on GPUs(b) Describing Tiled GPU GEMM with TileLang \ No newline at end of file diff --git a/tilelang/original/images/logo-row.svg b/tilelang/original/images/logo-row.svg new file mode 100644 index 0000000000000000000000000000000000000000..633243f3a9a003a903b859e8d8da5273b0f4cbf3 --- /dev/null +++ b/tilelang/original/images/logo-row.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tilelang/original/images/mha_performance_h100.png b/tilelang/original/images/mha_performance_h100.png new file mode 100644 index 0000000000000000000000000000000000000000..54c7cf94bf632badc2732716891b808b649de68d Binary files /dev/null and b/tilelang/original/images/mha_performance_h100.png differ diff --git a/tilelang/original/images/op_benchmark_a100_wq_gemv.png b/tilelang/original/images/op_benchmark_a100_wq_gemv.png new file mode 100644 index 0000000000000000000000000000000000000000..c31c80e50f8cb792aa637f380e524f4e190d3894 Binary files /dev/null and b/tilelang/original/images/op_benchmark_a100_wq_gemv.png differ diff --git a/tilelang/original/images/op_benchmark_consistent_gemm_fp16.png b/tilelang/original/images/op_benchmark_consistent_gemm_fp16.png new file mode 100644 index 0000000000000000000000000000000000000000..840e423e7199a96e8127cfe2750f7ebb60058bb3 Binary files /dev/null and b/tilelang/original/images/op_benchmark_consistent_gemm_fp16.png differ diff --git a/tilelang/original/images/op_benchmark_h100.png b/tilelang/original/images/op_benchmark_h100.png new file mode 100644 index 0000000000000000000000000000000000000000..3480ec522c90475c67db341e78c2d4b28b6f7c83 Binary files /dev/null and b/tilelang/original/images/op_benchmark_h100.png differ diff --git a/tilelang/original/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png b/tilelang/original/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png new file mode 100644 index 0000000000000000000000000000000000000000..90839aea728155fc51f944e04b37962b78e9f8c2 Binary files /dev/null and b/tilelang/original/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png differ diff --git a/tilelang/original/src/ir.cc b/tilelang/original/src/ir.cc new file mode 100644 index 0000000000000000000000000000000000000000..82a94cb8e4ff71e795518151a6d8dfde7636a5bb --- /dev/null +++ b/tilelang/original/src/ir.cc @@ -0,0 +1,416 @@ +/*! + * \file tl/ir.cc + * \brief Extension for the tvm script frontend. + * + */ + +#include "./transform/common/attr.h" +#include "op/builtin.h" +#include "tvm/ffi/any.h" +#include + +#include "support/ffi_aliases.h" +#include +#include +#include + +#include + +namespace tvm { +namespace tl { + +using namespace script::ir_builder::tir; + +static Var CreateEnvThread(String name, String thread_tag, DataType dtype) { + using namespace tvm::tir; + using namespace tvm::script::ir_builder; + IterVar iter_var(Range{nullptr}, Var(std::move(name), dtype), + tvm::tir::IterVarType::kThreadIndex, std::move(thread_tag)); + Var var = iter_var->var; + if (Optional opt_frame = + IRBuilder::Current()->FindFrame()) { + opt_frame.value()->env_threads.Set(var, iter_var); + } else { + LOG(FATAL) << "EnvThread can only be used inside a PrimFunc"; + } + return var; +} + +static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { + using namespace tvm::tir; + Var var = Var(name, dom->dtype); + // Create a frame that represents a loop over the given domain. 
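+  // KernelLaunch uses these frames to emit plain serial grid loops for CPU
+  // kernel frames, where no GPU thread bindings are available.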
+ ObjectPtr n = tvm::ffi::make_object(); + n->vars.push_back(var); + n->doms.push_back(Range(0, dom)); + n->f_make_for_loop = [](const Array &vars, const Array &doms, + const Array> &steps, + Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), 1); + ICHECK_EQ(doms.size(), 1); + Optional step = + !steps.empty() ? steps[0] : Optional(std::nullopt); + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body, + /*thread_binding=*/std::nullopt, + /*annotations=*/tvm::ffi::Map{}, + /*step=*/step); + }; + return ForFrame(n); +} + +ForFrame ParallelFor(const Array &extents, + const Map &annotations) { + using namespace tvm::tir; + ObjectPtr n = tvm::ffi::make_object(); + n->vars.reserve(extents.size()); + n->doms.reserve(extents.size()); + for (const auto &extent : extents) { + DataType dtype = extent.dtype(); + n->vars.push_back(Var("v", extent.dtype())); + n->doms.push_back(Range(make_const(dtype, 0), extent)); + } + n->f_make_for_loop = + [annotations](const Array &vars, const Array &doms, + const Array> &steps, Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); + for (int i = n - 1; i >= 0; --i) { + Range dom = doms[i]; + Var var = vars[i]; + Optional step = + i < steps.size() ? steps[i] : Optional(std::nullopt); + body = For(var, dom->min, dom->extent, ForKind::kParallel, body, + /*thread_binding=*/std::nullopt, /*annotations=*/annotations, + /*step=*/step); + } + return body; + }; + return ForFrame(n); +} + +ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, + const Array &order, + const Array &stages, + const Array> &sync, + const Array> &groups) { + using namespace tvm::tir; + ObjectPtr n = tvm::ffi::make_object(); + DataType dtype = stop.dtype(); + n->vars.push_back(Var("v", dtype)); + n->doms.push_back(Range(std::move(start), stop)); + n->f_make_for_loop = [=](const Array &vars, const Array &doms, + const Array> &steps, + Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); + ICHECK(n == 1); + Map anno; + if (num_stages > 0) + anno.Set("num_stages", PrimExpr(num_stages)); + if (!order.empty()) + anno.Set("tl_pipeline_order", order); + if (!stages.empty()) + anno.Set("tl_pipeline_stage", stages); + if (!sync.empty()) + anno.Set("tl_pipeline_sync", sync); + if (!groups.empty()) + anno.Set("tl_pipeline_group", groups); + Optional step = + !steps.empty() ? 
steps[0] : Optional(std::nullopt); + body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body, + /*thread_binding=*/std::nullopt, /*annotations=*/anno, + /*step=*/step); + return body; + }; + return ForFrame(n); +} + +ForFrame PersistentFor(const Array &domain, const PrimExpr &wave_size, + const PrimExpr &index, PrimExpr group_size) { + using namespace tvm::tir; + ICHECK(!domain.empty()); + ObjectPtr n = tvm::ffi::make_object(); + n->vars.reserve(domain.size()); + n->doms.reserve(domain.size()); + PrimExpr domain_size = domain[0]; + for (int i = 1; i < domain.size(); i++) { + domain_size *= domain[i]; + } + + auto waves = ceildiv(domain_size, wave_size); + auto loop_var = Var("w", waves.dtype()); + group_size = min(group_size, domain[domain.size() - 1]); + Array coord_vars; + + for (int i = 0; i < domain.size(); ++i) { + DataType dtype = domain[i].dtype(); + Var coord("v" + std::to_string(i), dtype); + coord_vars.push_back(coord); + n->vars.push_back(coord); + n->doms.push_back(Range(make_const(dtype, 0), domain[i])); + } + + Array grouped_domain; + grouped_domain.push_back(truncdiv(domain[domain.size() - 1], group_size)); + for (int i = 0; i < domain.size() - 1; ++i) { + grouped_domain.push_back(domain[i]); + } + grouped_domain.push_back(group_size); + + n->f_make_for_loop = [=](const Array &vars, const Array &doms, + const Array> &steps, + Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + Map anno; + Array idxs(grouped_domain.size(), PrimExpr()); + PrimExpr rem = loop_var * wave_size + index; + + for (int i = grouped_domain.size() - 1; i >= 1; --i) { + idxs.Set(i, truncmod(rem, grouped_domain[i])); + rem = truncdiv(rem, grouped_domain[i]); + } + idxs.Set(0, rem); + + auto out_if = tvm::tir::IfThenElse( + domain_size <= (loop_var * wave_size + index), + tvm::tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tl::loop_break(), {})), + Stmt()); + + arith::Analyzer analyzer; + Stmt new_body = body; + if (analyzer.CanProveGreaterEqual(waves, 2)) { + new_body = SeqStmt({out_if, body}); + } + Optional step = + !steps.empty() ? steps[0] : Optional(std::nullopt); + Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body, + /*thread_binding=*/std::nullopt, /*annotations=*/anno, + /*step=*/step); + for (int i = 0; i < vars.size() - 1; ++i) { + outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer); + } + outer = tvm::tir::LetStmt(vars[vars.size() - 1], + idxs[0] * group_size + idxs[vars.size()], outer); + return outer; + }; + + return ForFrame(n); +} + +/*! + * \brief A frame that represents a kernel launch. + * + * \sa KernelLaunchFrameNode + */ +class KernelLaunchFrameNode : public TIRFrameNode { +public: + Array frames; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "frames", &KernelLaunchFrameNode::frames); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame", + KernelLaunchFrameNode, TIRFrameNode); + +public: + TVM_DLL void EnterWithScope() final { + for (auto frame = frames.begin(); frame != frames.end(); ++frame) + (*frame)->EnterWithScope(); + } + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + TVM_DLL void ExitWithScope() final { + for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame) + (*frame)->ExitWithScope(); + } +}; + +/*! + * \brief Managed reference to KernelLaunchFrameNode. 
+ * + * \sa KernelLaunchFrameNode + */ +class KernelLaunchFrame : public TIRFrame { +public: + explicit KernelLaunchFrame(ObjectPtr data) + : TIRFrame(::tvm::ffi::UnsafeInit{}) { + ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame, + KernelLaunchFrameNode); +}; + +KernelLaunchFrame KernelLaunch(const Array &grid_size, + const Optional> &block_size_opt, + const Map &attrs) { + ObjectPtr n = + tvm::ffi::make_object(); + + // If the kernel is a CPU kernel, we don't need to launch any threads. + bool is_cpu_kernel_frame = + attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame); + + auto block_size = block_size_opt.value_or(Array()); + + if (is_cpu_kernel_frame) { + // Launch CPU Kernel + ICHECK(grid_size.size() >= 0); + ICHECK(block_size.empty()) << "CPU kernel cannot have block size"; + ICHECK(attrs.defined()); + // create grid loop var + for (int i = 0; i < grid_size.size(); i++) { + n->frames.push_back( + MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i])); + } + } else { + // Launch GPU Kernel + ICHECK(grid_size.size() <= 3); + if (!grid_size.empty()) + n->frames.push_back(LaunchThread( + CreateEnvThread("bx", "blockIdx.x", grid_size[0].dtype()), + grid_size[0])); + if (grid_size.size() > 1) + n->frames.push_back(LaunchThread( + CreateEnvThread("by", "blockIdx.y", grid_size[1].dtype()), + grid_size[1])); + if (grid_size.size() > 2) + n->frames.push_back(LaunchThread( + CreateEnvThread("bz", "blockIdx.z", grid_size[2].dtype()), + grid_size[2])); + if (block_size.defined()) { + ICHECK(block_size.size() <= 3); + if (!block_size.empty()) { + n->frames.push_back(LaunchThread( + CreateEnvThread("tx", "threadIdx.x", block_size[0].dtype()), + block_size[0])); + } + if (block_size.size() > 1) { + n->frames.push_back(LaunchThread( + CreateEnvThread("ty", "threadIdx.y", block_size[1].dtype()), + block_size[1])); + } + if (block_size.size() > 2) { + n->frames.push_back(LaunchThread( + CreateEnvThread("tz", "threadIdx.z", block_size[2].dtype()), + block_size[2])); + } + } + } + + if (attrs.defined()) { + auto empty_block = tvm::script::ir_builder::tir::Block(MainBlockName); + empty_block->annotations = attrs; + n->frames.push_back(empty_block); + } else { + n->frames.push_back(tvm::script::ir_builder::tir::Block(MainBlockName)); + } + + return KernelLaunchFrame(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tl.Parallel", ParallelFor) + .def("tl.Pipelined", PipelinedFor) + .def("tl.Persistent", PersistentFor) + .def("tl.KernelLaunch", KernelLaunch); +} + +class WarpSpecializeFrameNode : public TIRFrameNode { +public: + Array frames; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "frames", &WarpSpecializeFrameNode::frames); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame", + WarpSpecializeFrameNode, TIRFrameNode); + +public: + TVM_DLL void EnterWithScope() final { + for (auto frame = frames.begin(); frame != frames.end(); ++frame) + (*frame)->EnterWithScope(); + } + /*! + * \brief The method called when exiting RAII scope. 
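+   * Frames are exited in reverse order of entry so that nesting unwinds
+   * correctly.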
+ * \sa tvm::support::With + */ + TVM_DLL void ExitWithScope() final { + for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame) + (*frame)->ExitWithScope(); + } +}; + +class WarpSpecializeFrame : public TIRFrame { +public: + explicit WarpSpecializeFrame(ObjectPtr data) + : TIRFrame(::tvm::ffi::UnsafeInit{}) { + ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame, + WarpSpecializeFrameNode); +}; + +WarpSpecializeFrame WarpSpecialize(const Array &warp_group_ids, + const PrimExpr &thread_idx, + int warp_group_size = 128) { + ObjectPtr n = + tvm::ffi::make_object(); + PrimExpr condition; + std::vector warp_groups; + warp_groups.reserve(warp_group_ids.size()); + for (int i = 0; i < warp_group_ids.size(); i++) { + warp_groups.push_back(Downcast(warp_group_ids[i])->value); + } + std::sort(warp_groups.begin(), warp_groups.end()); + + // Merge consecutive groups + std::vector> merged; + for (int group : warp_groups) { + if (merged.empty() || group != merged.back().second) { + merged.emplace_back(group, group + 1); + } else { + merged.back().second = group + 1; + } + } + + for (const auto &[start, end] : merged) { + PrimExpr min_bound = IntImm(thread_idx.dtype(), start) * warp_group_size; + PrimExpr max_bound = IntImm(thread_idx.dtype(), end) * warp_group_size; + PrimExpr range_cond = (thread_idx >= min_bound) && (thread_idx < max_bound); + + if (condition.defined()) { + condition = tir::Or(condition, range_cond); + } else { + condition = range_cond; + } + } + IfFrame if_frame = If(condition); + AttrFrame attr_frame = Attr(Integer(0), "warp_specialize", Integer(1)); + n->frames.push_back(if_frame); + n->frames.push_back(Then()); + n->frames.push_back(attr_frame); + return WarpSpecializeFrame(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize); + KernelLaunchFrameNode::RegisterReflection(); + WarpSpecializeFrameNode::RegisterReflection(); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/gemm_layouts.cc b/tilelang/original/src/layout/gemm_layouts.cc new file mode 100644 index 0000000000000000000000000000000000000000..01d0ae66263c44ccc3c9a209e969915231b7ea13 --- /dev/null +++ b/tilelang/original/src/layout/gemm_layouts.cc @@ -0,0 +1,814 @@ +/*! + * \file layout/gemm_layouts.cc + * \brief Define Layout used in MMA and other operations. 
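+ *
+ * Each Fragment below maps a tile's logical (row, column) coordinates to the
+ * thread that owns the element and its local register index.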
+ * + */ + +#include + +#include + +#include "layout.h" + +namespace tvm { +namespace tl { + +IterVar make_itervar(std::string name, PrimExpr dom) { + Var var = Var(name, dom->dtype); + return IterVar(Range(0, dom), var, IterVarType::kDataPar); +} + +Fragment makeGemmFragment8x4() { + IterVar i = make_itervar("i", 8); + IterVar j = make_itervar("j", 4); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = FloorDiv(j->var, 1) + 4 * i; + PrimExpr index = FloorMod(j->var, 1); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragment8x8() { + IterVar i = make_itervar("i", 8); + IterVar j = make_itervar("j", 8); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = FloorDiv(j->var, 2) + 4 * i; + PrimExpr index = FloorMod(j->var, 2); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragment8x16() { + IterVar i = make_itervar("i", 8); + IterVar j = make_itervar("j", 16); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i; + PrimExpr index = FloorMod(j->var, 4); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragment8x8Transposed() { + IterVar i = make_itervar("i", 8); + IterVar j = make_itervar("j", 8); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j; + PrimExpr index = FloorMod(i->var, 2); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +/* +From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator +./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16 +--detail-instruction +*/ +Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) { + IterVar i = make_itervar("i", 16); + IterVar j = make_itervar("j", 16 * k_pack); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = 16 * FloorDiv(j->var, 4 * k_pack) + i; + PrimExpr index = FloorMod(j->var, 4 * k_pack); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) { + IterVar i = make_itervar("i", 16 * k_pack); + IterVar j = make_itervar("j", 16); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = 16 * FloorDiv(i->var, 4 * k_pack) + j; + PrimExpr index = FloorMod(i->var, 4 * k_pack); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) { + IterVar i = make_itervar("i", 16); + IterVar j = make_itervar("j", 32 * k_pack); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = 16 * FloorDiv(j->var, 8 * k_pack) + i; + PrimExpr index = FloorMod(j->var, 8 * k_pack); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) { + IterVar i = make_itervar("i", 32 * k_pack); + IterVar j = make_itervar("j", 16); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = 16 * FloorDiv(i->var, 8 * k_pack) + j; + PrimExpr index = FloorMod(i->var, 8 * k_pack); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragmentC16x16CDNA() { + IterVar i = make_itervar("i", 16); + IterVar j = make_itervar("j", 16); + IterVar rep = make_itervar("rep", 1); + PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i; + PrimExpr index = FloorMod(j->var, 4); + return Fragment({i, j}, {index}, forward_thread, rep); +} + +Fragment makeGemmFragmentC_F64(const int block_m, const int block_n, + const int warp_m, const int warp_n) { + 
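For reference, the `makeGemmFragment8x8` mapping above places element `(i, j)` in thread `j / 2 + 4 * i` at register slot `j % 2`. Below is a standalone check (assuming a 32-thread warp, as the fragment implies) that this covers all 64 elements with two slots per thread and no collisions:

```cpp
#include <cassert>

// Mirror of the 8x8 fragment mapping: 64 elements onto 32 threads, 2 slots each.
int main() {
  bool used[32][2] = {};
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      int thread = j / 2 + 4 * i;   // forward_thread
      int slot = j % 2;             // forward_index
      assert(thread < 32 && !used[thread][slot]);
      used[thread][slot] = true;
    }
  }
  return 0;
}
```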
ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0); + ICHECK(warp_n % 8 == 0); + auto base_layout = makeGemmFragment8x8(); + auto warp_layout = + base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); + auto block_layout = + warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false); + return block_layout; +} + +Fragment makeGemmFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size) { + if (element_size == 64) + return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n); + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n; + auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false); + auto warp_layout = + base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); + auto block_layout = + warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false); + return block_layout; +} + +Fragment makeGemmSparseFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size) { + if (element_size == 64) { + ICHECK(false) << "Not supported"; + } + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n; + auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false); + // NOTE: This func wasn't implemented by following the CUTLASS 2 iterator + // but by inspecting the output, it appears that we first need to + // repeat the warp layout while avoiding duplicate thread mappings. + auto warp_layout = + base_layout->Repeat({warp_m / 16, warp_n / 8}, false, false); + auto block_layout = + warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); + return block_layout; +} + +Fragment makeGemmFragmentCDCU(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size) { + if (element_size == 64) + LOG(FATAL) << "Not supported"; + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; + auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false); + auto warp_layout = + base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false); + auto block_layout = + warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); + return block_layout; +} + +Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size) { + if (element_size == 64) + LOG(FATAL) << "Not supported"; + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; + auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false); + auto warp_layout = + base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true); + auto block_layout = + warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); + return block_layout; +} + +Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size) { + ICHECK(block_m % warp_m == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + + auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, + false); // 16 x N (1 warp) + auto block_layout = 
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, + true, false); // 16*Y x N (Y warp) + return block_layout->Repeat({warp_m / 16, 1}, false, false); +} + +Fragment makeGemmFragmentA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, const int element_size, + bool transposed) { + // assume not transposed + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0); + ICHECK(block_k % 16 == 0); + // Only support 8-bit and 16-bit + ICHECK(element_size == 8 || element_size == 16 || element_size == 32) + << "unsupported element bitwidth=" << element_size; + + if (transposed) { + auto base_layout = + makeGemmFragment8x8Transposed()->Repeat({2, 2}, false, true); + auto warp_layout = base_layout->Repeat({1, block_m / warp_m}, true, false) + ->Replicate(block_n / warp_n); + auto block_layout = + warp_layout->Repeat({block_k / 16, warp_m / 16}, false, true); + return block_layout; + } else { + if (element_size == 8) { + auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true) + ->Replicate(block_n / warp_n); + auto block_layout = + warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false); + return block_layout; + } else if (element_size == 16) { + auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true) + ->Replicate(block_n / warp_n); + auto block_layout = + warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false); + return block_layout; + } else if (element_size == 32) { + auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true) + ->Replicate(block_n / warp_n); + auto block_layout = + warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false); + return block_layout; + } else { + ICHECK(0); + return Fragment(); + } + } +} + +Fragment makeGemmFragmentB(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, bool transposed) { + // transposed + ICHECK(warp_n % 8 == 0); + ICHECK(block_k % 16 == 0); + if (transposed) { + auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false); + auto warp_layout = base_layout->Replicate(block_m / warp_m) + ->Repeat({block_n / warp_n, 1}, true, false); + auto block_layout = + warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false); + return block_layout; + } else { + auto base_layout = + makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false); + auto warp_layout = base_layout->Replicate(block_m / warp_m) + ->Repeat({1, block_n / warp_n}, true); + auto block_layout = + warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true); + return block_layout; + } +} + +Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, const int element_size, + const int k_pack, bool transposed) { + // assume not transposed + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0); + const int mfma_k = k_pack * (element_size == 16 ? 16 : 32); + ICHECK(block_k % mfma_k == 0); + ICHECK(element_size == 8 || element_size == 16) + << "element bitwidth=" << element_size; + if (transposed) { + auto base_layout = + element_size == 16 + ? 
makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat( + {1, 1}, false, false) + : makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat( + {1, 1}, false, false); + auto warp_layout = + base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true); + auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true) + ->Replicate(block_n / warp_n); + return block_layout; + } else { + auto base_layout = + element_size == 16 + ? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false) + : makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false, false); + auto warp_layout = + base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false); + auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true) + ->Replicate(block_n / warp_n); + return block_layout; + } +} + +Fragment makeGemmFragment32x32(int element_size) { + IterVar i = make_itervar("i", 32); + IterVar j = make_itervar("j", 32); + IterVar rep = make_itervar("rep", 1); + ICHECK(element_size == 16 || element_size == 32); + if (element_size == 16) { + PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 + + FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16; + PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 + + FloorDiv(FloorMod(i, 8), 4) * 8 + + FloorDiv(FloorMod(j, 8), 4) * 16; + return Fragment({i, j}, {idx}, thd, rep); + } else { + PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) + + FloorDiv(FloorMod(i, 16), 8) * 4 + + FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16; + PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) + + FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 + + FloorDiv(FloorMod(j, 8), 4) * 16; + return Fragment({i, j}, {idx}, thd, rep); + } +} + +Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + int element_size) { + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 32 == 0); + ICHECK(warp_n % 32 == 0); + auto base_layout = makeGemmFragment32x32(element_size); + auto warp_layout = + base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false); + auto block_layout = + warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true); + return block_layout; +} + +Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n) { + // assume not transposed + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 32 == 0); + ICHECK(block_k % 4 == 0); + // this is a special case + IterVar i = make_itervar("i", 32); + IterVar j = make_itervar("j", 4); + IterVar rep = make_itervar("rep", 2); + PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) + + FloorMod(i, 4) + 8 * rep; + PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4; + Fragment base_layout = Fragment({i, j}, {idx}, thd, rep); + auto warp_layout = + base_layout->Repeat({warp_m / 32, block_k / 4}, false, false); + auto block_layout = warp_layout->Replicate(block_n / warp_n) + ->Repeat({block_m / warp_m, 1}, true); + return block_layout; +} + +PrimExpr xor2x2(const PrimExpr &i, const PrimExpr &j) { + return FloorMod(i + j, 2); +} + +PrimExpr xor4x4(const PrimExpr &i, const PrimExpr &j) { + PrimExpr i0 = FloorMod(i, 2); + PrimExpr j0 = FloorMod(j, 2); + PrimExpr i1 = FloorDiv(i, 2); + PrimExpr j1 = FloorDiv(j, 2); + return 2 * xor2x2(i1, j1) + xor2x2(i0, j0); +} + +PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) { + PrimExpr i0 = FloorMod(i, 2); + PrimExpr j0 = 
FloorMod(j, 2); + PrimExpr i1 = FloorDiv(i, 2); + PrimExpr j1 = FloorDiv(j, 2); + return 2 * xor4x4(i1, j1) + xor2x2(i0, j0); +} + +// Layout swizzling for 32 bytes +Layout makeQuarterBankSwizzleLayout(int stride, int continuous, + int element_size) { + // Swizzle 1 bit + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + int vector_size = 128 / element_size; + ICHECK(stride % 8 == 0) << "stride=" << stride; + ICHECK(continuous % (vector_size * 2) == 0) + << "continuous=" << continuous << ", vector_size=" << vector_size; + PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 2); + PrimExpr c = FloorMod(FloorDiv(j, vector_size), 2); + PrimExpr vec = FloorMod(j, vector_size); + PrimExpr c_swizzle = xor2x2(c, FloorDiv(s, 4)); + PrimExpr index = vec + (c_swizzle + s * 2) * vector_size; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + +// Layout swizzling for 64 bytes +Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) { + // Swizzle 2 bit + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + int vector_size = 128 / element_size; + ICHECK(stride % 8 == 0) << "stride=" << stride; + ICHECK(continuous % (vector_size * 4) == 0) + << "continuous=" << continuous << ", vector_size=" << vector_size; + PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 4); + PrimExpr c = FloorMod(FloorDiv(j, vector_size), 4); + PrimExpr vec = FloorMod(j, vector_size); + PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2)); + PrimExpr index = vec + (c_swizzle + s * 4) * vector_size; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + +// Layout swizzling for 128 bytes +Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) { + // Swizzle 3 bit + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + int vector_size = 128 / element_size; + ICHECK(stride % 8 == 0) << "stride=" << stride; + ICHECK(continuous % (vector_size * 8) == 0) + << "continuous=" << continuous << ", vector_size=" << vector_size; + PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8); + PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8); + PrimExpr vec = FloorMod(j, vector_size); + PrimExpr c_swizzle = xor8x8(c, s); + PrimExpr index = vec + (c_swizzle + s * 8) * vector_size; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + +// Detail implementation please ref to +// bitblas::tl::mfma_layout::make_mfma_swizzle_layout +Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, + int kPack = 1) { + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + const int vecSize = (64 / element_size) * kPack; + const int innerDimLength = continuous; + const int typeWidthInBit = element_size; + + const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + IterVar row = make_itervar("row", stride); + IterVar col = make_itervar("col", continuous); + PrimExpr phase = FloorMod(row / perPhase, maxPhase); + PrimExpr colOffSwizzled = ((col / vecSize) ^ phase) * vecSize; + PrimExpr colOffOrdered = FloorMod(col, vecSize); + PrimExpr colOff = colOffSwizzled + colOffOrdered; + + return Layout(Array{row, col}, {row, colOff}); +} + +Layout 
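The `xor8x8` helper above is an arithmetic encoding of the bitwise XOR of two 3-bit values, so the full-bank (128-byte) swizzle stores logical vector column `c` of row `s` at physical column `c ^ s`. A standalone check, using fp16 numbers for illustration (`element_size = 16`, hence `vector_size = 8`), that the swizzle keeps each row a permutation and spreads a fixed logical column across distinct physical columns:

```cpp
#include <cassert>

int main() {
  for (int s = 0; s < 8; ++s) {      // row within the 8x8 tile of 16-byte vectors
    bool seen[8] = {};
    for (int c = 0; c < 8; ++c) {    // logical vector column
      int c_swizzle = c ^ s;         // equals xor8x8(c, s)
      assert(!seen[c_swizzle]);      // each row remains a permutation
      seen[c_swizzle] = true;
    }
  }
  for (int c = 0; c < 8; ++c) {      // fixed logical column across rows
    bool seen[8] = {};
    for (int s = 0; s < 8; ++s) {
      assert(!seen[c ^ s]);          // lands in 8 distinct physical columns,
      seen[c ^ s] = true;            // i.e. different shared-memory bank groups
    }
  }
  return 0;
}
```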
makeGemmABLayoutF64_Kinner(int stride, int continuous) { + // Swizzle<2, 0, 4> + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + PrimExpr tc = FloorDiv(j, 16); + PrimExpr ts = FloorDiv(i, 4); + PrimExpr c = FloorMod(j, 16); + PrimExpr s = FloorMod(i, 4); + PrimExpr swizzled_c = FloorDiv(c, 4) * 4 + xor4x4(FloorMod(c, 4), s); + PrimExpr index = swizzled_c + s * 16; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + +Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) { + // Swizzle<2, 2, 2> + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + PrimExpr tc = FloorDiv(j, 16); + PrimExpr ts = FloorDiv(i, 4); + PrimExpr c = FloorMod(j, 16); + PrimExpr s = FloorMod(i, 4); + PrimExpr swizzled_c = FloorMod(c, 4) + xor4x4(FloorDiv(c, 4), s) * 4; + PrimExpr index = swizzled_c + s * 16; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + +// The Default Layout for Tensor Access +Layout makeGemmLayoutLinear(int stride, int continuous) { + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + return Layout(Array{i, j}, {i * continuous + j}); +} + +Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) { + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + int padded = continuous; + // Add 128 bits padding when the last dim is a multiple of 256 bits + if ((element_size * continuous) % 256 == 0) + padded += 128 / element_size; + return Layout(Array{i, j}, {i * padded + j}); +} + +Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) { + ICHECK(stride % 32 == 0 && continuous % 32 == 0); + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + PrimExpr vec_contiguous_idx = FloorDiv(j, 4); + PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8); + + PrimExpr bit2 = + FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) + + FloorDiv(vec_strided_within_tile, 4), + 2); + PrimExpr bit1 = xor2x2(FloorDiv(FloorMod(i, 8), 4), + FloorDiv(FloorMod(vec_strided_within_tile, 4), 2)); + PrimExpr permuted_vec_contiguous = + FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1; + + PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 + + vec_contiguous_idx * stride * 4; + return Layout(Array{i, j}, {offset}); +} + +Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) { + ICHECK(stride % 4 == 0 && continuous % 64 == 0); + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + PrimExpr vec_contiguous_idx = FloorDiv(j, 8); + PrimExpr vec_strided_idx = i; + PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8); + PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4); + PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8); + PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4); + + PrimExpr permuted_strided_within_tile = FloorDiv(tile_contiguous_residual, 2); + PrimExpr permuted_contiguous_within_tile = + FloorMod(tile_contiguous_residual, 2) * 4 + + xor4x4(tile_strided_residual, permuted_strided_within_tile); + + PrimExpr element_strided = + permuted_strided_within_tile + tile_strided_idx * 4; + PrimExpr element_contiguous = + FloorMod(j, 8) + + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8; + PrimExpr offset = element_strided * continuous + element_contiguous; + return Layout(Array{i, j}, {offset}); +} + +Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) { + ICHECK(stride % 4 == 
0 && continuous % 64 == 0); + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + PrimExpr vec_contiguous_idx = FloorDiv(j, 8); + PrimExpr vec_strided_idx = i; + PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8); + PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4); + PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8); + PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4); + + PrimExpr permuted_strided_within_tile = FloorMod(tile_contiguous_residual, 4); + PrimExpr permuted_contiguous_within_tile = + FloorDiv(tile_contiguous_residual, 4) * 4 + + xor4x4(tile_strided_residual, permuted_strided_within_tile); + + PrimExpr element_strided = + permuted_strided_within_tile + tile_strided_idx * 4; + PrimExpr element_contiguous = + FloorMod(j, 8) + + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8; + PrimExpr offset = element_strided * continuous + element_contiguous; + return Layout(Array{i, j}, {offset}); +} + +Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, + bool k_inner) { + if (k_inner && continuous % 32 == 0 && stride % 32 == 0) + return MakeGemmVoltaABLayoutCrosswise(stride, continuous); + if (is_a && continuous % 64 == 0 && stride % 4 == 0) + return MakeGemmVoltaALayoutCongruous(stride, continuous); + if (!is_a && continuous % 64 == 0 && stride % 4 == 0) + return MakeGemmVoltaBLayoutCongruous(stride, continuous); + return makeGemmABLayoutPadded(stride, continuous, 16); +} + +// ref: +// https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/tensor_op_multiplicand_sm75.h#L54 +// Although the four settings (T or NT) used distinct layouts in CUTLASS, they +// appeared to result in the same mem layout +Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous, + int elementsize, int crosswise) { + /// This layout is optimized for 128b accesses + static int const kAccessSize = 128; + int kCrosswise = crosswise; + + int kElementSize = elementsize; + int kElementsPerAccess = kAccessSize / kElementSize; + + /// Contiguous dimension of the tile shape matches one shared memory cache + /// line - 128B. For 128bit access size, it equals to 8 accesses. + int kTileShapeContiguous = 128 / (kAccessSize / 8); + + int kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise; + + ICHECK(kFactor > 0) + << "kCrosswise should be no large than one shared memory cache line."; + + /// The strided dimension needs to be at least (WarpSize(32) / + /// kTileShapeContiguous) for a warp to access. To ensure conflict free + /// access, it also needs to be at least (kTileShapeContiguous / kFactor). + /// See comments below + /// Fundamental tile shape in units of vectors to guarantee bank conflict free + /// shared memory load/store. + /// For kFactor = 1, TileShape = <8, 8> + /// For kFactor > 1, TileShape = <8, 4> + int kTileShapeStride = + ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous)) + ? 
(kTileShapeContiguous / kFactor) + : (32 / kTileShapeContiguous); + + const int kPartitionShapeContiguous = 4; + const int kPartitionShapeStride = 4; + + // NOTE: it's always row major for tl + IterVar i = make_itervar("i", mat_stride); + IterVar j = make_itervar("j", mat_continuous); + + PrimExpr vec_contiguous_idx = FloorDiv(j, kElementsPerAccess); + PrimExpr vec_strided_idx = FloorDiv(i, kFactor); + + // Compute the fundamental tile being accessed + PrimExpr tile_contiguous_idx = + FloorDiv(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor)); + + PrimExpr tile_contiguous_residual = + FloorMod(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor)) + + (FloorMod(i, kFactor) * FloorDiv(kTileShapeContiguous, kFactor)); + PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, kTileShapeStride); + + // Compute the 'partition' within the fundamental tile + PrimExpr partition_contiguous_idx = + FloorDiv(tile_contiguous_residual, kPartitionShapeContiguous); + PrimExpr partition_strided_idx = + FloorDiv(tile_strided_residual, kPartitionShapeStride); + + PrimExpr partition_contiguous_residual = + FloorMod(tile_contiguous_residual, kPartitionShapeContiguous); + PrimExpr partition_strided_residual = + FloorMod(tile_strided_residual, kPartitionShapeStride); + + // + // Then swizzle + // + + PrimExpr permuted_vec_contiguous_within_partition = xor4x4( + partition_contiguous_residual, FloorMod(partition_strided_residual, 4)); + + PrimExpr permuted_partition_contiguous_within_tile = + xor2x2(partition_contiguous_idx, FloorMod(partition_strided_idx, 2)); + + // + // Compute final element location + // + + PrimExpr element_contiguous = + (tile_contiguous_idx * kTileShapeContiguous + + permuted_partition_contiguous_within_tile * kPartitionShapeContiguous + + permuted_vec_contiguous_within_partition) * + kElementsPerAccess + + FloorMod(j, kElementsPerAccess); + + const PrimExpr &element_strided = vec_strided_idx; + + const int stride = mat_continuous; + + return Layout(Array{i, j}, + {element_contiguous + element_strided * stride * kFactor}); +} + +Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous, + int elementsize) { + int kCrosswise = std::min(mat_continuous, (1024 / elementsize)); + return makeTensorOpMultiplicand(mat_stride, mat_continuous, elementsize, + kCrosswise); +} + +/*! + * \brief Creates a memory layout for GEMM's A or B matrices. + * + * This function selects an appropriate memory layout based on the matrix + * dimensions, element size, continuity, and a k-factor. It aims to optimize + * memory access patterns, potentially using swizzling techniques or specialized + * layouts for different data types and hardware characteristics. + * + * \param mat_stride The leading dimension of the matrix (e.g., K for a + * row-major M x K matrix). This is the number of elements to skip to get to the + * same column in the next row (row-major) or to the same row in the next column + * (column-major). \param mat_continuous The length of the dimension stored + * contiguously in memory (e.g., K for a row-major M x K matrix, or M for a + * column-major M x K matrix). \param continuity The size of the dimension that + * is continuous from the perspective of memory bank access. This is used to + * select specific swizzling strategies. It might be the same as mat_continuous + * or different based on tiling or hardware details. + * \param element_size The size of each element in the matrix, in bits (e.g., 8, + * 16, 32, 64). 
\param k_inner Whether the K dimension is in the inner loop. + * selection, particularly for fp64 and int8 types. It often relates to how the + * K dimension of the GEMM (M x K * K x N) is handled or tiled. + * - For fp64 (element_size == 64): + * - k_inner == false often implies K is in the "outer" loop + * (e.g., KxN matrix). + * - k_inner == true often implies K is in the "inner" loop + * (e.g., NxK matrix). + * - For int8 (element_size == 8): + * - k_inner == false uses a padded layout. + * \return A Layout object representing the chosen memory layout. + */ +Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, + int element_size, bool k_inner) { + if (element_size == 64) { + if (!k_inner && continuity % 16 == 0) // float64 KxN + return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); + if (k_inner && continuity % 16 == 0) // float64 NxK + return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); + return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); + } + int vector_size = 128 / element_size; + if (!k_inner && element_size == 8) // int8 KxN + return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 8) == 0) + // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, + // element_size); + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else { + return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); + } +} + +Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, + int continuity, int element_size, bool k_inner) { + if (element_size == 64) { + if (!k_inner && continuity % 16 == 0) // float64 KxN + return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); + if (k_inner && continuity % 16 == 0) // float64 NxK + return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); + return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + } + int vector_size = 128 / element_size; + + if (mat_continuous % (vector_size * 8) == 0) + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 2) == 0) + return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % vector_size == 0) + return makeGemmLayoutLinear(mat_stride, mat_continuous); + else + ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride + << ", continuous=" << mat_continuous + << ", element_size=" << element_size << ", k_inner=" << k_inner; +} + +Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, + int element_size, bool k_inner) { + if (element_size == 64) { + ICHECK(0) << "float64 on sm100 is not supported now"; + } + int vector_size = 128 / element_size; + if (mat_continuous % (vector_size * 8) == 0) + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 2) == 0) + return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % vector_size == 0) + return 
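The selection ladder in `makeGemmABLayout` above can be summarized as: int8 with K outer falls back to the padded layout, otherwise the widest bank swizzle whose tile width divides the contiguous extent is used, with padding as the last resort. A standalone sketch (`PickABLayout` and the returned strings are illustrative stand-ins for the real `Layout`-returning functions):

```cpp
#include <cassert>
#include <string>

std::string PickABLayout(int continuous, int element_size, bool k_inner) {
  int vector_size = 128 / element_size;            // elements per 128-bit vector
  if (!k_inner && element_size == 8) return "padded";
  if (continuous % (vector_size * 8) == 0) return "full_bank_128B";
  if (continuous % (vector_size * 4) == 0) return "half_bank_64B";
  return "padded";
}

int main() {
  assert(PickABLayout(64, 16, true) == "full_bank_128B");  // 64 fp16 = 128-byte rows
  assert(PickABLayout(32, 16, true) == "half_bank_64B");   // 32 fp16 = 64-byte rows
  assert(PickABLayout(24, 16, true) == "padded");          // not divisible -> pad
  assert(PickABLayout(128, 8, false) == "padded");         // int8 KxN special case
  return 0;
}
```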
makeGemmLayoutLinear(mat_stride, mat_continuous); + else + ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride + << ", continuous=" << mat_continuous + << ", element_size=" << element_size << ", k_inner=" << k_inner; + __builtin_unreachable(); // to prevent compiler warning +} + +Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, + int kPack) { + return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack); +} +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/layout.cc b/tilelang/original/src/layout/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..63d9c04011f080e81c9f202fa42e13457b16dcb5 --- /dev/null +++ b/tilelang/original/src/layout/layout.cc @@ -0,0 +1,846 @@ +/*! + * \file layout/layout.cc + * + */ + +#include "layout.h" +#include +#include + +#include +#include +#include + +#include "arith/pattern_match.h" +#include "tvm/node/functor.h" +#include "tvm/node/repr_printer.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +static Var getPlaceholder(const std::string &s) { + static std::unordered_map map; + if (map.find(s) == map.end()) { + map[s] = Var(s); + } + return map[s]; +} + +Var ReplicationPlaceholder() { return getPlaceholder("_rep"); } +Var InputPlaceholder(size_t idx) { + return getPlaceholder(std::string{'_', char('i' + idx)}); +} + +Map LayoutNode::getVarMap() const { + Map map; + for (size_t i = 0; i < InputDim(); i++) { + map.Set(InputPlaceholder(i), {0, input_size_[i]}); + } + return map; +} + +Map FragmentNode::getVarMap() const { + auto map = LayoutNode::getVarMap(); + map.Set(ReplicationPlaceholder(), {0, ReplicateExtent()}); + return map; +} + +LayoutNode::LayoutNode(Array input_size, + Array forward_index) { + input_size_ = input_size; + arith::Analyzer analyzer; + UpdateAnalyzer(&analyzer); + forward_index_ = forward_index.Map( + [&](const PrimExpr &e) { return analyzer.Simplify(e); }); +} + +Layout::Layout(Array forward_var, Array forward_index) { + Map vmap; + Array input_size; + for (size_t i = 0; i < forward_var.size(); i++) { + vmap.Set(forward_var[i]->var, InputPlaceholder(i)); + CHECK(is_zero(forward_var[i]->dom->min)); + input_size.push_back(forward_var[i]->dom->extent); + } + forward_index = + forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); + auto n = tvm::ffi::make_object(input_size, forward_index); + data_ = std::move(n); +} + +Layout::Layout(Array input_size, Array forward_index) { + auto n = tvm::ffi::make_object(input_size, forward_index); + data_ = std::move(n); +} + +void LayoutNode::RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("input_size", &LayoutNode::input_size_) + .def_ro("forward_index", &LayoutNode::forward_index_) + .def("_DebugOutput", &LayoutNode::DebugOutput); +} + +void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const { + for (const auto &[var, dom] : getVarMap()) { + analyzer->Bind(var, dom); + } +} + +Array LayoutNode::GetForwardVars() const { + Array vars; + for (size_t i = 0; i < InputDim(); i++) { + vars.push_back(InputPlaceholder(i)); + } + return vars; +} + +Array LayoutNode::OutputShape() const { + Array ret(OutputDim(), 1); + arith::Analyzer analyzer; + UpdateAnalyzer(&analyzer); + for (size_t i = 0; i < ret.size(); i++) { + auto ist = analyzer.int_set(forward_index_[i] + 1); + if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) { + // Analyzer couldn't form an IntervalSet (e.g. 
bitwise ops). + // Fall back to ConstIntBound to derive a safe extent. + auto cib = analyzer.const_int_bound(forward_index_[i]); + if (cib->min_value != arith::ConstIntBound::kNegInf && + cib->max_value != arith::ConstIntBound::kPosInf && + cib->min_value >= 0) { + // extent = max - min + 1, using 64-bit integer literal + ret.Set(i, Integer(cib->max_value - cib->min_value + 1)); + } else { + // Last-resort conservative fallback to avoid OOB/crash + // Prefer to keep dimension from known input_size_ if available. + if (i < input_size_.size()) { + ret.Set(i, input_size_[i]); + } else { + ret.Set(i, Integer(1)); + } + } + } else { + ret.Set(i, ist.max()); + } + } + return ret; +} + +Array LayoutNode::Forward(const Array &vars) const { + if (vars.empty()) + return forward_index_; + ICHECK_GE(vars.size(), InputDim()); + + // Take the last InputDim() elements for transformation + Array transform_vars; + for (size_t i = vars.size() - InputDim(); i < vars.size(); i++) { + transform_vars.push_back(vars[i]); + } + + Map vmap; + for (size_t i = 0; i < InputDim(); i++) { + vmap.Set(InputPlaceholder(i), transform_vars[i]); + } + + Array transformed = forward_index_.Map( + [&](const PrimExpr &e) { return Substitute(e, vmap); }); + // Concatenate with the remaining elements from vars + Array result; + for (size_t i = 0; i < vars.size() - InputDim(); i++) { + result.push_back(vars[i]); + } + for (const auto &expr : transformed) { + result.push_back(expr); + } + + return result; +} + +Fragment FragmentNode::Repeat(const Array &repeats, + bool repeat_on_thread, + bool lower_dim_first) const { + ICHECK_EQ(repeats.size(), InputDim()); + Array new_input_size; + Map vmap; + for (size_t i = 0; i < InputDim(); i++) { + new_input_size.push_back(input_size_[i] * repeats[i]); + vmap.Set(InputPlaceholder(i), + FloorMod(InputPlaceholder(i), InputShape()[i])); + } + + PrimExpr repeats_index = 0, repeat_stride = 1; + if (lower_dim_first) { + for (int i = InputDim() - 1; i >= 0; i--) { + repeats_index += + repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]); + repeat_stride *= repeats[i]; + } + } else { + for (size_t i = 0; i < InputDim(); i++) { + repeats_index += + repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]); + repeat_stride *= repeats[i]; + } + } + + if (repeat_on_thread) { + PrimExpr thread_size = ThreadExtent(); + auto new_forward_index = forward_index_.Map( + [&](const PrimExpr &e) { return Substitute(e, vmap); }); + auto new_forward_thread = + Substitute(forward_thread_, vmap) + thread_size * repeats_index; + return Fragment(new_input_size, new_forward_index, new_forward_thread, + replicate_size_, std::nullopt); + } else { + ICHECK(OutputDim() == 1); + PrimExpr frag_len = OutputShape()[0]; + Array new_forward_index = {Substitute(forward_index_[0], vmap) + + frag_len * repeats_index}; + PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); + return Fragment(new_input_size, new_forward_index, new_forward_thread, + replicate_size_, std::nullopt); + } +} + +Fragment FragmentNode::Replicate(int repeats) const { + ICHECK(repeats >= 1); + Map vmap; + vmap.Set(ReplicationPlaceholder(), + FloorMod(ReplicationPlaceholder(), ReplicateExtent())); + PrimExpr new_forward_thread = + Substitute(forward_thread_, vmap) + + ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent()); + return Fragment(input_size_, forward_index_, new_forward_thread, + ReplicateExtent() * repeats, std::nullopt); +} + +Fragment FragmentNode::DeReplicate() const { + ICHECK(OutputDim() == 1); + 
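The tile-ordering flag in `FragmentNode::Repeat` above only changes how the repeat tiles are flattened into the index (or thread) offset. A standalone sketch with an assumed base shape `[8, 8]` and repeats `{2, 3}`:

```cpp
#include <cassert>

int main() {
  const int shape0 = 8, shape1 = 8;
  const int rep0 = 2, rep1 = 3;
  int i = 12, j = 10;                        // logical element in tile (1, 1)
  int t0 = i / shape0, t1 = j / shape1;
  int idx_lower_dim_first = t0 * rep1 + t1;  // last axis varies fastest -> 4
  int idx_higher_dim_first = t1 * rep0 + t0; // first axis varies fastest -> 3
  assert(t0 == 1 && t1 == 1);
  assert(idx_lower_dim_first == 4 && idx_higher_dim_first == 3);
  return 0;
}
```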
arith::Analyzer analyzer; + UpdateAnalyzer(&analyzer); + int factor = 1; + auto rep_size = as_const_int(ReplicateExtent()); + auto idx_size = as_const_int(OutputShape()[0]); + if (rep_size && idx_size) { + factor = arith::ZeroAwareGCD(*rep_size, *idx_size); + } + if (factor == 1) + return tvm::ffi::GetRef(this); + + Map vmap; + vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor + + FloorMod(forward_index_[0], factor)); + PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); + Array new_forward_index = {FloorDiv(forward_index_[0], factor)}; + return Fragment(input_size_, new_forward_index, new_forward_thread, + int(*rep_size) / factor, std::nullopt); +} + +Fragment FragmentNode::BindThreadRange(Range thread_range) const { + auto n = tvm::ffi::make_object(*this); + n->thread_range_ = thread_range; + return Fragment(n); +} + +std::pair LayoutNode::InverseWithLevel() const { + arith::Analyzer analyzer; + auto collect_symbolic = [&](const Array &shape) { + Array symbolic_dims; + for (const auto &dim : shape) { + if (!as_const_int(dim)) { + symbolic_dims.push_back(dim); + } + } + return symbolic_dims; + }; + Array symbolic_dims = collect_symbolic(input_size_); + Array output_shape = OutputShape(); + symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(), + output_shape.end()); + symbolic_dims = collect_symbolic(symbolic_dims); + bool is_static_shape = symbolic_dims.empty(); + auto level = is_static_shape ? arith::IterMapLevel::Bijective + : arith::IterMapLevel::NoCheck; + if (!is_static_shape) { + // Runtime guards keep dynamic tails safe, so we allow NoCheck here and + // warn. + DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to " + "NoCheck; symbolic dims: " + << symbolic_dims; + } + arith::IterMapResult res = + arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); + if (!res->errors.empty()) { + std::ostringstream msg; + msg << "Layout " << DebugOutput() << " has errors: " << res->errors; + throw NormalizeIterException(msg.str()); + } + + auto outputs_shape = OutputShape(); + Array outputs; + for (size_t i = 0; i < OutputDim(); i++) { + outputs.push_back(InputPlaceholder(i)); + } + + auto inv = arith::InverseAffineIterMap(res->indices, outputs); + + Array backward_index; + for (size_t i = 0; i < InputDim(); i++) { + if (inv.find(InputPlaceholder(i)) != inv.end()) { + backward_index.push_back(inv[InputPlaceholder(i)]); + } else { + backward_index.push_back(0); + } + } + + return {Layout(outputs_shape, backward_index), level}; +} + +Layout LayoutNode::Reshape(const Array &shape, + arith::Analyzer *analyzer, + const PrimExpr rescale_num, + const PrimExpr rescale_den) const { + + // Fast path: if shape is the same, return the original layout + if (StructuralEqual()(InputShape(), shape)) { + return ffi::GetRef(this); + } + + // Step 1. Prove the product relation holds under rescale: + // prod(InputShape) * rescale_num == prod(shape) * rescale_den + PrimExpr input_shape_product = Integer(1); + for (const auto &dim : InputShape()) { + input_shape_product *= dim; + } + PrimExpr shape_product = Integer(1); + for (const auto &dim : shape) { + shape_product *= dim; + } + + // Use provided analyzer if present, otherwise a local fallback to avoid + // potential null dereference paths flagged by static analysis. + arith::Analyzer fallback_analyzer; + arith::Analyzer *az = analyzer ? 
analyzer : &fallback_analyzer; + ICHECK(az->CanProveEqual(input_shape_product * rescale_num, + shape_product * rescale_den)) + << "InputShape() = " << InputShape() << " shape = " << shape + << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den; + + // Step 2. Create new forward indices by reshaping + // For each dimension in the new shape, we create a placeholder variable + Array new_vars; + new_vars.reserve(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype()); + az->Bind(var, Range(0, shape[i])); + new_vars.push_back(var); + } + // Step 3. Compute the flat index from new shape indices + // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn + PrimExpr flat_index = Integer(0); + for (size_t i = 0; i < shape.size(); ++i) { + PrimExpr stride = Integer(1); + for (size_t j = i + 1; j < shape.size(); ++j) { + stride = stride * shape[j]; + } + flat_index = flat_index + new_vars[i] * stride; + } + // Convert new flat index (in units of new elements) to the old flat index + // (in units of old elements) using the rational rescale factor. + // old_flat = floor((flat_index * rescale_den) / rescale_num) + PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num); + // Step 4. Convert flat index back to original shape indices + // For original shape [s0, s1, ..., sm]: + // i0 = flat_index // (s1 * s2 * ... * sm) + // i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm) + // ... + Array original_indices; + PrimExpr remaining = old_flat_index; + for (size_t i = 0; i < InputShape().size(); ++i) { + PrimExpr stride = Integer(1); + for (size_t j = i + 1; j < InputShape().size(); ++j) { + stride = stride * InputShape()[j]; + } + original_indices.push_back(floordiv(remaining, stride)); + remaining = floormod(remaining, stride); + } + // Step 5. Substitute original indices into forward_index_ + Array new_forward_index; + for (const auto &fwd_expr : forward_index_) { + PrimExpr substituted = fwd_expr; + // Replace each InputPlaceholder(i) with original_indices[i] + for (size_t i = 0; i < InputShape().size(); ++i) { + substituted = + Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}}); + } + new_forward_index.push_back(az->Simplify(substituted)); + } + for (size_t i = 0; i < new_vars.size(); ++i) { + new_forward_index = + Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}}); + } + return Layout(shape, new_forward_index); +} + +Layout FragmentNode::Reshape(const Array &shape, + arith::Analyzer *analyzer, + const PrimExpr rescale_num, + const PrimExpr rescale_den) const { + + // Fast path: identical input shape, return self + if (StructuralEqual()(InputShape(), shape)) { + return ffi::GetRef(this); + } + + // 1) Prove total number of elements remains the same + PrimExpr input_prod = Integer(1); + for (const auto &d : InputShape()) + input_prod *= d; + PrimExpr shape_prod = Integer(1); + for (const auto &d : shape) + shape_prod *= d; + + // Use provided analyzer if present, otherwise a local fallback. + arith::Analyzer fallback_analyzer; + arith::Analyzer *az = analyzer ? 
analyzer : &fallback_analyzer; + ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den)) + << "InputShape() = " << InputShape() << " shape = " << shape + << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den + << " input fragment layout is = " << DebugOutput(); + + // 2) Build flat index from new-shape indices + Array new_vars; + new_vars.reserve(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + // Cannot use InputPlaceholder(i) here, because it would cause name capture + // (variable capture) with InputPlaceholder(i) in upper scopes. Therefore, + // we must create a fresh variable here to avoid confusion when + // substituting. + auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype()); + az->Bind(var, Range(0, shape[i])); + new_vars.push_back(var); + } + + PrimExpr flat = Integer(0); + for (size_t i = 0; i < shape.size(); ++i) { + PrimExpr stride = Integer(1); + for (size_t j = i + 1; j < shape.size(); ++j) + stride = stride * shape[j]; + flat = flat + new_vars[i] * stride; + } + // Convert to old flat index units using the rational rescale factor. + // old_flat = floor((flat * rescale_den) / rescale_num) + PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num); + // 3) Recover original indices from flat index + Array orig_indices; + PrimExpr remain = old_flat; + for (size_t i = 0; i < InputShape().size(); ++i) { + PrimExpr stride = Integer(1); + for (size_t j = i + 1; j < InputShape().size(); ++j) + stride = stride * InputShape()[j]; + orig_indices.push_back(floordiv(remain, stride)); + remain = floormod(remain, stride); + } + // 4) Substitute old placeholders with expressions of new indices + Array new_forward_index; + for (const auto &e : forward_index_) { + PrimExpr cur = e; + for (size_t i = 0; i < InputShape().size(); ++i) { + cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}}); + } + cur = az->Simplify(cur); + new_forward_index.push_back(cur); + } + PrimExpr new_forward_thread = forward_thread_; + for (size_t i = 0; i < InputShape().size(); ++i) { + new_forward_thread = Substitute(new_forward_thread, + {{InputPlaceholder(i), orig_indices[i]}}); + } + new_forward_thread = az->Simplify(new_forward_thread); + for (size_t i = 0; i < new_vars.size(); ++i) { + auto var = new_vars[i]; + new_forward_index = + Substitute(new_forward_index, {{var, InputPlaceholder(i)}}); + new_forward_thread = + Substitute(new_forward_thread, {{var, InputPlaceholder(i)}}); + } + Fragment reshaped(shape, new_forward_index, new_forward_thread, + ReplicateExtent(), std::nullopt); + if (thread_range_.defined()) { + reshaped = reshaped->BindThreadRange(thread_range_); + } + return reshaped; +} + +Layout LayoutNode::Inverse() const { + auto inverse_result = InverseWithLevel(); + return std::move(inverse_result.first); +} + +PrimExpr infer_fragment_index(const Map &input_iters, + const PrimExpr &forward_thread, + arith::Analyzer *analyzer) { + Array splits = DivideUnusedIterators( + {forward_thread}, ToIterVars(input_iters), analyzer); + + Array split_without_rep; + for (const auto &split : splits) { + CHECK(split->source->source.as()); + if (split->source->source.as().value().same_as( + ReplicationPlaceholder())) + continue; + split_without_rep.push_back(split); + } + return MakeFlattenedExpression(split_without_rep); +} + +FragmentNode::FragmentNode(Array input_size, + Array forward_index, + PrimExpr forward_thread, PrimExpr replicate_size) { + input_size_ = input_size; + replicate_size_ = replicate_size; + arith::Analyzer 
analyzer; + UpdateAnalyzer(&analyzer); + forward_thread_ = analyzer.Simplify(forward_thread); + if (forward_index.empty()) { + forward_index = { + infer_fragment_index(getVarMap(), forward_thread_, &analyzer)}; + } + forward_index_ = forward_index.Map( + [&](const PrimExpr &e) { return analyzer.Simplify(e); }); +} + +Fragment::Fragment(Array forward_var, Array forward_index, + PrimExpr forward_thread, IterVar thread_replicate) { + Map vmap; + Array input_size; + PrimExpr replicate_size = 1; + for (size_t i = 0; i < forward_var.size(); i++) { + vmap.Set(forward_var[i]->var, InputPlaceholder(i)); + CHECK(is_zero(forward_var[i]->dom->min)); + input_size.push_back(forward_var[i]->dom->extent); + } + if (thread_replicate.defined()) { + ICHECK(is_zero(thread_replicate->dom->min)); + replicate_size = thread_replicate->dom->extent; + vmap.Set(thread_replicate->var, ReplicationPlaceholder()); + } + forward_index = + forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); + forward_thread = Substitute(forward_thread, vmap); + + auto n = tvm::ffi::make_object(input_size, forward_index, + forward_thread, replicate_size); + data_ = std::move(n); +} + +Fragment::Fragment(Array input_size, Array forward_index, + PrimExpr forward_thread, PrimExpr replicate_size, + Optional replicate_var) { + if (replicate_var.defined()) { + forward_thread = Substitute( + forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}}); + } + auto n = tvm::ffi::make_object(input_size, forward_index, + forward_thread, replicate_size); + data_ = std::move(n); +} + +Fragment Fragment::FullyReplicated(Array shape, + PrimExpr thread_extent) { + return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent, + std::nullopt); +} + +// which means the forward_thread is rep_var -> lambda i, rep: rep +bool FragmentNode::IsCompletedReplicated() const { + arith::Analyzer analyzer; + return ExprDeepEqual()(analyzer.Simplify(forward_thread_), + ReplicationPlaceholder()); +} + +arith::IterMapResult FragmentNode::DetectInjective() const { + // lei:To perform injective check, we need to reverse the layout + // and use surjective check, now we use bijective check for convenience + // can be relaxed in future + arith::Analyzer analyzer; + // Build a flat indices array: [forward_thread_, forward_index_[...]] + Array indices; + indices.push_back(forward_thread_); + for (const auto &e : forward_index_) { + indices.push_back(e); + } + + // Mirror Layout::InverseWithLevel(): if any participating shape is + // symbolic, relax to NoCheck and rely on runtime guards elsewhere. + auto collect_symbolic = [&](const Array &shape) { + Array symbolic_dims; + for (const auto &dim : shape) { + if (!as_const_int(dim)) { + symbolic_dims.push_back(dim); + } + } + return symbolic_dims; + }; + + Array symbolic_dims = collect_symbolic(InputShape()); + Array output_shape = OutputShape(); + symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(), + output_shape.end()); + // Also consider replicate size for fragments + if (!as_const_int(ReplicateExtent())) { + symbolic_dims.push_back(ReplicateExtent()); + } + symbolic_dims = collect_symbolic(symbolic_dims); + + bool is_static_shape = symbolic_dims.empty(); + auto level = is_static_shape ? 
arith::IterMapLevel::Bijective + : arith::IterMapLevel::NoCheck; + if (!is_static_shape) { + DLOG(WARNING) + << "Fragment::DetectInjective on symbolic layout, falling back to " + << "NoCheck; symbolic dims: " << symbolic_dims; + } + + return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer); +} + +PrimExpr FragmentNode::ThreadExtent() const { + Array ret(OutputDim(), 1); + arith::Analyzer analyzer; + UpdateAnalyzer(&analyzer); + auto ist = analyzer.int_set(forward_thread_ + 1); + return ist.max(); +} + +Array FragmentNode::GetForwardVars() const { + Array vars; + if (*as_const_int(ReplicateExtent()) > 1) { + vars.push_back(ReplicationPlaceholder()); + } + for (size_t i = 0; i < InputDim(); i++) { + vars.push_back(InputPlaceholder(i)); + } + return vars; +} + +PrimExpr FragmentNode::ForwardThread(const Array &vars, + const Optional &rep_var) const { + Map vmap; + ICHECK_EQ(vars.size(), InputDim()); + for (size_t i = 0; i < InputDim(); i++) { + vmap.Set(InputPlaceholder(i), vars[i]); + } + if (rep_var.defined()) + vmap.Set(ReplicationPlaceholder(), rep_var.value()); + + return Substitute(forward_thread_, vmap); +} + +Layout FragmentNode::Inverse() const { + auto result = InverseWithLevel(); + return std::move(result.first); +} + +std::pair FragmentNode::InverseWithLevel() const { + auto input_size_copy = input_size_; + input_size_copy.push_back(ReplicateExtent()); + auto forward_index_copy = forward_index_; + forward_index_copy.push_back( + Substitute(forward_thread_, + {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}})); + auto fwd = Layout(input_size_copy, forward_index_copy); + return fwd->InverseWithLevel(); +} + +Fragment FragmentNode::CondenseReplicateVar() const { + arith::Analyzer analyzer; + auto input_iters = getVarMap(); + input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()}); + PrimExpr new_forward_thread; + IterVar new_thread_replicate; + std::tie(new_forward_thread, new_thread_replicate) = + CompressIterator(forward_thread_, ToIterVars(input_iters), + ReplicationPlaceholder(), &analyzer); + return Fragment(input_size_, forward_index_, new_forward_thread, + new_thread_replicate->dom->extent, new_thread_replicate->var); +} + +std::string LayoutNode::DebugOutput() const { + std::stringstream ss; + ss << "Layout(" << InputShape() << " -> " << OutputShape() + << ", transform: " << GetForwardVars() << " -> " << GetForwardIndex() + << ")"; + return ss.str(); +} + +std::string FragmentNode::DebugOutput() const { + std::stringstream ss; + ss << "Fragment(" << InputShape() << " -> " << OutputShape() + << ", replicate: " << ReplicateExtent() << ", thread: " << ThreadExtent() + << ", forward_thread: " << forward_thread_ + << ", forward_index: " << GetForwardIndex(); + if (thread_range_.defined()) { + ss << ", thread_range: " << thread_range_; + } + ss << ")"; + return ss.str(); +} + +bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const { + bool ret = StructuralEqual()(this->InputShape(), other->InputShape()); + ret &= StructuralEqual()(this->OutputShape(), other->OutputShape()); + if (!skip_index) { + ret &= StructuralEqual()(this->forward_index_, other->forward_index_); + } + return ret; +} + +bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const { + // Fragment Layout Comparison can skip the index comparison + // when the output shape is the same, as we can do + // a[i, j] = b[j, i] in register level. 
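To see what `InverseWithLevel` recovers in the simplest case, here is a standalone round-trip check using the 8x8 MMA fragment from `gemm_layouts.cc` as the forward mapping (thread `= j / 2 + 4 * i`, slot `= j % 2`); the inverse expressions in the comments are derived by hand for illustration:

```cpp
#include <cassert>

int main() {
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j) {
      int thread = j / 2 + 4 * i;            // forward_thread
      int slot = j % 2;                      // forward_index
      assert(thread / 4 == i);               // inverse: i = thread / 4
      assert(2 * (thread % 4) + slot == j);  // inverse: j = 2*(thread % 4) + slot
    }
  return 0;
}
```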
+ + bool ret = StructuralEqual()(this->InputShape(), other->InputShape()); + if (!ret) { + // may be broadcast case + return true; + } + if (this->thread_range_.defined() && other->thread_range_.defined()) { + ret &= StructuralEqual()(this->thread_range_, other->thread_range_); + } + ret &= StructuralEqual()(this->OutputShape(), other->OutputShape()); + ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent()); + ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent()); + if (!skip_index) { + ret &= StructuralEqual()(this->forward_index_, other->forward_index_); + } + return ret; +} + +void FragmentNode::RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("forward_thread", &FragmentNode::forward_thread_) + .def_ro("replicate_size", &FragmentNode::replicate_size_) + .def("_DebugOutput", &FragmentNode::DebugOutput); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef &obj, ReprPrinter *p) { + auto *node = static_cast(obj.get()); + p->stream << node->DebugOutput(); + }) + .set_dispatch([](const ObjectRef &obj, ReprPrinter *p) { + auto *node = static_cast(obj.get()); + p->stream << node->DebugOutput(); + }); + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tl.Layout", + [](PackedArgs args, Any *rv) { + *rv = Layout(args[0].cast>(), + args[1].cast>()); + }) + .def("tl.Layout_input_shape", + [](Layout layout) { return layout->InputShape(); }) + .def("tl.Layout_output_shape", + [](Layout layout) { return layout->OutputShape(); }) + .def("tl.Layout_inverse", [](Layout layout) { return layout->Inverse(); }) + .def("tl.Layout_index", + [](Layout layout) { return layout->GetForwardIndex(); }) + .def("tl.Layout_forward_vars", + [](Layout layout) { return layout->GetForwardVars(); }) + .def("tl.Layout_is_equal", + [](Layout layout, Layout other) { + const LayoutNode *other_node = other.as(); + return layout->IsEqual(other_node); + }) + .def_packed("tl.Fragment", + [](PackedArgs args, Any *rv) { + *rv = Fragment( + /*forward_var=*/args[0].cast>(), + /*forward_index=*/args[1].cast>(), + /*forward_thread=*/args[2].cast(), + /*thread_replicate=*/args[3].cast()); + }) + .def("tl.Fragment_is_equal", + [](Fragment fragment, Fragment other) { + const FragmentNode *other_node = other.as(); + return fragment->IsEqual(other_node); + }) + .def("tl.Fragment_thread_size", + [](Fragment fragment) { return fragment->ThreadExtent(); }) + .def("tl.Fragment_thread", + [](Fragment fragment) { return fragment->GetForwardThread(); }) + .def("tl.Fragment_repeat", + [](Fragment fragment, Array repeats, bool repeat_on_thread, + bool lower_dim_first) { + return fragment->Repeat(repeats, repeat_on_thread, + lower_dim_first); + }) + .def("tl.Fragment_replicate", + [](Fragment fragment, int repeats) { + return fragment->Replicate(repeats); + }) + .def("tl.Fragment_condense_rep_var", + [](Fragment fragment) { return fragment->CondenseReplicateVar(); }) + .def("tl.make_swizzled_layout", + [](int stride, int continuous, int element_size, bool k_inner, + bool allow_pad = true) { + if (allow_pad) { + return makeGemmABLayout(stride, continuous, continuous, + element_size, k_inner); + } else { + return makeGemmABLayoutHopper(stride, continuous, continuous, + element_size, k_inner); + } + }) + .def("tl.make_volta_swizzled_layout", + [](int stride, int mat_continuous, bool is_a, bool k_inner) { + return makeGemmVoltaABLayout(stride, mat_continuous, is_a, + k_inner); + }) + 
.def("tl.make_wgmma_swizzled_layout", + [](int stride, int mat_continuous, int continuity, int element_size, + bool k_inner) { + return makeGemmABLayoutHopper(stride, mat_continuous, continuity, + element_size, k_inner); + }) + .def("tl.make_tcgen05mma_swizzled_layout", + [](int stride, int mat_continuous, int continuity, int element_size, + bool k_inner) { + return makeGemmABLayoutSm100(stride, mat_continuous, continuity, + element_size, k_inner); + }) + .def("tl.make_full_bank_swizzled_layout", + [](int stride, int continuous, int element_size) { + return makeFullBankSwizzleLayout(stride, continuous, element_size); + }) + .def("tl.make_half_bank_swizzled_layout", + [](int stride, int continuous, int element_size) { + return makeHalfBankSwizzleLayout(stride, continuous, element_size); + }) + .def("tl.make_quarter_bank_swizzled_layout", + [](int stride, int continuous, int element_size) { + return makeQuarterBankSwizzleLayout(stride, continuous, + element_size); + }) + .def("tl.make_linear_layout", [](int stride, int continuous) { + return makeGemmLayoutLinear(stride, continuous); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + LayoutNode::RegisterReflection(); + FragmentNode::RegisterReflection(); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/layout.h b/tilelang/original/src/layout/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..9cd23905a20052f9cc96a5c6a8b6fdca2f926932 --- /dev/null +++ b/tilelang/original/src/layout/layout.h @@ -0,0 +1,269 @@ +/*! + * \file Layout.h + * + */ + +#ifndef TVM_TL_LAYOUT_LAYOUT_H_ +#define TVM_TL_LAYOUT_LAYOUT_H_ + +#include +#include +#include +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +// Common layout-related exceptions +class LayoutConflictException : public std::exception { +public: + const char *what() const noexcept override { return msg_.c_str(); } + explicit LayoutConflictException(const std::string &msg) : msg_(msg) {} + +private: + std::string msg_; +}; + +class LoopLayoutInjectiveException : public std::exception { +public: + const char *what() const noexcept override { return msg_.c_str(); } + explicit LoopLayoutInjectiveException(const std::string &msg) : msg_(msg) {} + +private: + std::string msg_; +}; + +class Layout; +class Fragment; + +class LayoutNode : public Object { +public: + LayoutNode() = default; + LayoutNode(Array input_size, Array forward_index); + + size_t InputDim() const { return input_size_.size(); } + + size_t OutputDim() const { return forward_index_.size(); } + + Array InputShape() const { return input_size_; } + + Array OutputShape() const; + + Array GetForwardIndex() const { return forward_index_; } + + virtual Array GetForwardVars() const; + + virtual Array Forward(const Array &vars) const; + + virtual Layout Inverse() const; + + // Reshape the layout to a new logical shape. When aliasing buffers of + // different dtypes, the element count may change while the underlying + // byte-size stays equal. Use rescale_num/rescale_den to represent the + // ratio between the old element size and the new element size in bytes. + // Specifically, define factor = rescale_num / rescale_den where: + // new_num_elems = old_num_elems * factor + // For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1. + // i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4. 
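+  // Illustrative example: viewing a [64, 32] f32 buffer as i8 uses
+  // rescale_num=4, rescale_den=1, so the reshaped layout covers
+  // 64 * 32 * 4 = 8192 elements (e.g. a new shape of [64, 128]).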
+ virtual Layout Reshape(const Array &shape, + arith::Analyzer *analyzer, + const PrimExpr rescale_num = Integer(1), + const PrimExpr rescale_den = Integer(1)) const; + + virtual std::pair InverseWithLevel() const; + + virtual std::string DebugOutput() const; + + virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const; + + static void RegisterReflection(); + TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object); + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = + kTVMFFISEqHashKindTreeNode; + +protected: + virtual Map getVarMap() const; + void UpdateAnalyzer(arith::Analyzer *analyzer) const; + Array forward_index_; + Array input_size_; +}; + +/*! + * \brief Layout reference class. + */ +class Layout : public ObjectRef { +public: + TVM_DLL Layout(Array forward_var, Array forward_index); + TVM_DLL Layout(Array input_size, Array forward_index); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode); +}; + +class FragmentNode : public LayoutNode { +public: + FragmentNode() = default; + FragmentNode(Array input_size, Array forward_index, + PrimExpr forward_thread, PrimExpr replicate_size); + + PrimExpr GetForwardThread() const { return forward_thread_; } + + Array GetForwardVars() const final; + + Layout Inverse() const final; + + Layout Reshape(const Array &shape, arith::Analyzer *analyzer, + const PrimExpr rescale_num = Integer(1), + const PrimExpr rescale_den = Integer(1)) const; + + std::pair InverseWithLevel() const final; + + PrimExpr ThreadExtent() const; + + PrimExpr ReplicateExtent() const { return replicate_size_; }; + + PrimExpr ForwardThread(const Array &vars, + const Optional &rep_var) const; + + Fragment Repeat(const Array &repeats, bool repeat_on_thread, + bool lower_dim_first = true) const; + + Fragment Replicate(int repeats) const; + + Fragment DeReplicate() const; + + Fragment CondenseReplicateVar() const; + + std::string DebugOutput() const final; + + Fragment BindThreadRange(Range thread_range) const; + + Range ThreadRange() const { return thread_range_; } + + bool IsEqual(const FragmentNode *other, bool skip_index = false) const; + + bool IsCompletedReplicated() const; + + arith::IterMapResult DetectInjective() const; + + static void RegisterReflection(); + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode); + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = + kTVMFFISEqHashKindTreeNode; + +protected: + Map getVarMap() const final; + Range thread_range_; + PrimExpr forward_thread_; + PrimExpr replicate_size_; +}; + +/*! + * \brief Fragment reference class. + */ +class Fragment : public Layout { +public: + TVM_DLL Fragment(Array forward_var, Array forward_index, + PrimExpr forward_thread, IterVar thread_replicate); + + TVM_DLL Fragment(Array input_size, Array forward_index, + PrimExpr forward_thread, PrimExpr replicate_size, + Optional replicate_var); + + /*! + * \brief Create a fully replicated fragment layout. + * + * A fully replicated fragment means all threads hold identical copies of the + * entire buffer. This is useful for index buffers or masks that need to be + * accessed uniformly across all threads. + * + * \param shape The shape of the buffer. + * \param thread_extent The number of threads. + * \return A Fragment where each thread has a complete copy of all elements. 
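+   * For example (illustrative), FullyReplicated({8, 8}, 128) yields a
+   * fragment whose ReplicateExtent() is 128: each of the 128 threads keeps
+   * its own copy of all 64 elements.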
+ */ + TVM_DLL static Fragment FullyReplicated(Array shape, + PrimExpr thread_extent); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode); +}; + +Var InputPlaceholder(size_t idx); +Var ReplicationPlaceholder(); +IterVar make_itervar(std::string name, PrimExpr dom); + +Fragment makeGemmFragment8x8(); +Fragment makeGemmFragment8x8Transposed(); +Fragment makeGemmFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmSparseFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmFragmentCDCU(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmFragmentA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, const int element_size, + bool transposed = false); +Fragment makeGemmFragmentB(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, bool transposed = false); + +Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n, const int element_size, + const int k_pack, bool transposed = false); + +// Default Memory Layout +Layout makeGemmLayoutLinear(int stride, int continuous); +Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); +Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, + int element_size, bool k_inner = true); +Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, + int continuity, int element_size, + bool k_inner = true); +Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, + int element_size, bool k_inner = true); +Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, + int kPack); + +Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, + const int warp_m, const int warp_n, + const int element_size); +Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, + const int block_k, const int warp_m, + const int warp_n); +Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, + bool k_inner = true); + +Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous, + int elementsize, int crosswise); +Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous, + int elementsize); + +Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size); +Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size); +Layout makeQuarterBankSwizzleLayout(int stride, int continuous, + int element_size); + +namespace attr { +// BlockAttr, Containing the layout for all the buffers in the block +constexpr const char *kLayoutMap = "layout_map"; +} // namespace attr + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_LAYOUT_LAYOUT_H_ diff --git a/tilelang/original/src/layout/swizzle.cc b/tilelang/original/src/layout/swizzle.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3222b9c0d12c1a6a03e3b184a888a95f2160ad9 --- /dev/null +++ b/tilelang/original/src/layout/swizzle.cc @@ -0,0 +1,109 @@ +/*! 
+ * \file layout/swizzle.cc + * \brief Define swizzled layout + * + */ + +#include "swizzle.h" + +#include +#include +#include + +#include + +namespace tvm { +namespace tl { + +SwizzlePattern::SwizzlePattern(int bits, int base, int shift) + : bits_(bits), base_(base), shift_(shift) { + ICHECK(bits >= 0); + ICHECK(base >= 0); + ICHECK(shift >= 0); + ICHECK(shift >= bits); +} + +PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const { + int base = (1 << base_); + int mask = ((1 << bits_) - 1) << shift_; + PrimExpr high = FloorDiv(expr, base); + PrimExpr low = FloorMod(expr, base); + high = bitwise_xor(high, right_shift(bitwise_and(high, mask), shift_)); + return low + high * base; +} + +bool SwizzlePattern::operator==(const SwizzlePattern &other) const { + return std::tie(base_, bits_, shift_) == + std::tie(other.base_, other.bits_, other.shift_); +} + +SwizzledLayoutNode::SwizzledLayoutNode(Array input_size, + Array forward_index, + SwizzlePattern pattern) + : pattern_(pattern) { + input_size_ = input_size; + arith::Analyzer analyzer; + UpdateAnalyzer(&analyzer); + forward_index_ = forward_index.Map( + [&](const PrimExpr &e) { return analyzer.Simplify(e); }); +} + +Array SwizzledLayoutNode::Forward(const Array &vars) const { + auto expr_list = LayoutNode::Forward(vars); + auto expr = expr_list.back(); + expr_list.pop_back(); + expr_list.push_back(pattern_.swizzle(expr)); + return expr_list; +} + +std::string SwizzledLayoutNode::DebugOutput() const { + std::stringstream ss; + ss << LayoutNode::DebugOutput(); + ss << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " " + << pattern_.Shift(); + return ss.str(); +} + +Layout SwizzledLayoutNode::Inverse() const { + ICHECK(0) << "Not Implemented."; + return {}; +} + +bool SwizzledLayoutNode::IsEqual(const SwizzledLayoutNode *other, + bool skip_index) const { + return LayoutNode::IsEqual(other, skip_index) && pattern_ == other->pattern_; +} + +SwizzledLayout::SwizzledLayout(Array forward_var, + Array forward_index, + SwizzlePattern pattern) { + Map vmap; + Array input_size; + for (size_t i = 0; i < forward_var.size(); i++) { + vmap.Set(forward_var[i]->var, InputPlaceholder(i)); + CHECK(is_zero(forward_var[i]->dom->min)); + input_size.push_back(forward_var[i]->dom->extent); + } + forward_index = + forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); + + auto n = tvm::ffi::make_object(input_size, forward_index, + pattern); + data_ = std::move(n); +} + +SwizzledLayout::SwizzledLayout(Array input_size, + Array forward_index, + SwizzlePattern pattern) { + auto n = tvm::ffi::make_object(input_size, forward_index, + pattern); + data_ = std::move(n); +} + +void SwizzledLayoutNode::RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/swizzle.h b/tilelang/original/src/layout/swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..b0bf5f1c962235a8af0e57fd67efc2a5592b86be --- /dev/null +++ b/tilelang/original/src/layout/swizzle.h @@ -0,0 +1,71 @@ +/*! + * \file swizzle.h + * \brief Define swizzled layout + * + */ + +#ifndef TVM_TL_LAYOUT_SWIZZLE_H_ +#define TVM_TL_LAYOUT_SWIZZLE_H_ + +#include "layout.h" + +namespace tvm { +namespace tl { + +/*! 
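+ * A swizzle pattern remaps a flattened offset i (see swizzle.cc): with
+ * low = i mod 2^base and high = i div 2^base, the lowest `bits` bits of
+ * high are XOR-ed with its bits at positions [shift, shift + bits), and
+ * the result is low + high * 2^base.
+ *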
+ * \brief Swizzle pattern + */ +class SwizzlePattern { +public: + SwizzlePattern() = default; + SwizzlePattern(int bits, int base, int shift); + PrimExpr swizzle(PrimExpr expr) const; + int Bits() const { return bits_; } + int Base() const { return base_; } + int Shift() const { return shift_; } + bool operator==(const SwizzlePattern &other) const; + +private: + int bits_; + int base_; + int shift_; +}; + +/*! + * \brief Layout with swizzle + */ +class SwizzledLayoutNode : public LayoutNode { +public: + SwizzledLayoutNode() = default; + SwizzledLayoutNode(Array input_size, Array forward_index, + SwizzlePattern pattern); + + Array Forward(const Array &vars) const final; + Layout Inverse() const final; + std::string DebugOutput() const final; + bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const; + static void RegisterReflection(); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode, + LayoutNode); + +private: + SwizzlePattern pattern_; +}; + +/*! + * \brief SwizzledLayout reference class. + */ +class SwizzledLayout : public Layout { +public: + TVM_DLL SwizzledLayout(Array forward_var, + Array forward_index, SwizzlePattern pattern); + TVM_DLL SwizzledLayout(Array input_size, + Array forward_index, SwizzlePattern pattern); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout, + SwizzledLayoutNode); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_LAYOUT_SWIZZLE_H_ diff --git a/tilelang/original/src/layout/tcgen05_layout.cc b/tilelang/original/src/layout/tcgen05_layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..64e0cdd646fc71442b2eb5b3c0deef92cf355ae2 --- /dev/null +++ b/tilelang/original/src/layout/tcgen05_layout.cc @@ -0,0 +1,111 @@ +/*! + * \file layout/tcgen05_layout.cc + * \brief Define Layout used in tcgen05.ld/st. 
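+ *
+ * Each Tcgen05Meta below pairs a tcgen05.ld/st intrinsic name with a
+ * Fragment mapping a physical tmem (row, col) coordinate to the
+ * (thread_id, value_id) holding it, plus the instruction width in columns;
+ * expandTcgen05Layout then tiles that per-instruction fragment over a
+ * larger tmem column extent and multiple warpgroups.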
+ * + */ + +#include + +#include + +#include "layout.h" +#include "tcgen05_layout.h" + +namespace tvm { +namespace tl { + +static IterVar make_itervar(std::string name, Range dom) { + Var var = Var(name, dom->min->dtype); + return IterVar(dom, var, IterVarType::kDataPar); +} + +Tcgen05Meta getTcgen05Meta_32dp32b() { + constexpr int INST_WIDTH = 1; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{"tl::tcgen05_ld_32dp32bNx", + Fragment({inst_row, inst_col}, {inst_col}, {inst_row}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp64b() { + constexpr int INST_WIDTH = 2; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp64bNx", + Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 16)}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorDiv(FloorMod(inst_row, 16), 8) + + FloorMod(inst_col, 2) * 2}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp128b() { + constexpr int INST_WIDTH = 4; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp128bNx", + Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 8)}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorMod(inst_col, 4)}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp256b() { + constexpr int INST_WIDTH = 8; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp256bNx", + Fragment( + {inst_row, inst_col}, + {FloorMod(inst_col, 2) + FloorDiv(FloorMod(inst_row, 32), 8) * 2}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorDiv(FloorMod(inst_col, 8), 2)}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +std::tuple +expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent, + int num_threads, Range row_dom, Range col_dom) { + static constexpr int WARPGROUP_SIZE = 128; + ICHECK(num_threads % WARPGROUP_SIZE == 0); + int num_wgs = num_threads / WARPGROUP_SIZE; + +#define FAIL_IF(cond) \ + if (cond) { \ + return {false, Fragment(), 0}; \ + } + + FAIL_IF(tmem_phy_col_extent % meta.width != 0); + int total_chunks = tmem_phy_col_extent / meta.width; + FAIL_IF(total_chunks % num_wgs != 0); // Otherwise the layout is not bijective + int num_chunks_each_wg = total_chunks / num_wgs; + int num_cols_each_wg = num_chunks_each_wg * meta.width; + int num_elems_each_thread_in_one_chunk = meta.width * 128 / WARPGROUP_SIZE; + + IterVar iter_row = make_itervar("row", row_dom); + IterVar iter_col = make_itervar("col", col_dom); + PrimExpr thread_idx = + meta.frag->ForwardThread({iter_row, FloorMod(iter_col, meta.width)}, + std::nullopt) + + FloorDiv(iter_col, num_cols_each_wg) * WARPGROUP_SIZE; + PrimExpr val_idx = + meta.frag->Forward({iter_row, FloorMod(iter_col, meta.width)})[0] + + FloorDiv(FloorMod(iter_col, num_cols_each_wg), meta.width) * + num_elems_each_thread_in_one_chunk; + + return {true, + Fragment({iter_row, iter_col}, {val_idx}, thread_idx, + make_itervar("rep", Range(0, 1))), + num_chunks_each_wg}; +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/tcgen05_layout.h b/tilelang/original/src/layout/tcgen05_layout.h new file mode 100644 index 
0000000000000000000000000000000000000000..8148d7077e0455c843684a2f5831ef208c24a2a7 --- /dev/null +++ b/tilelang/original/src/layout/tcgen05_layout.h @@ -0,0 +1,33 @@ +/*! + * \file layout/tcgen05_layout.cc + * + */ +#pragma once + +#include "layout.h" + +namespace tvm { +namespace tl { + +// A structure encapsulating the metadata for a particular tcgen05.ld/st +// instruction. +struct Tcgen05Meta { + std::string intrinsics_name; + Fragment frag; // Physical tmem coord |-> (thread_id, val_id) in fragment + int width; +}; + +// Obtain the metadata for tcgen05.ld/st instructions. +Tcgen05Meta getTcgen05Meta_32dp32b(); +Tcgen05Meta getTcgen05Meta_32dp64b(); +Tcgen05Meta getTcgen05Meta_32dp128b(); +Tcgen05Meta getTcgen05Meta_32dp256b(); + +// Expand a tcgen05 layout along thread_idx/value_idx (T/V) dimensions. +// Return {is_success, fragment, num_chunks_each_wg} +std::tuple +expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent, + int num_threads, Range row_dom, Range col_dom); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/utils.cc b/tilelang/original/src/layout/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2a788b2483f0b4de2227aeffe759ee74e6a1e80 --- /dev/null +++ b/tilelang/original/src/layout/utils.cc @@ -0,0 +1,277 @@ +/*! + * \file layout/utils.cc + * \brief Some arith tools for layout & fragment inference + * + */ + +#include "utils.h" + +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace arith; + +bool CanProveDivisible(const PrimExpr &lhs, const PrimExpr &rhs) { + const auto *clhs = lhs.as(); + const auto *crhs = rhs.as(); + if (crhs && crhs->value == 0) { + return false; + } else if (clhs && crhs) { + return clhs->value % crhs->value == 0; + } + + return false; +} + +/*! + * \brief Collector that collects the outgoing split reference of each IterMark. + * + * These out-going splits can then be used to check if the iterators are + * independent. + */ +class IterMarkSplitCollector { +public: + // mark all IterMarks that are visited. + std::unordered_set visited_; + // each iter mark to its outgoing splits that are referenced. + std::unordered_map, ObjectPtrHash, + ObjectPtrEqual> + mark2splits_; + /*! + * \brief Collect all mark2splits recursively from indices. + * \param indices The iterator of interest. 
+ */ + void Collect(const Array &indices) { + for (IterSumExpr sum_expr : indices) { + for (IterSplitExpr split : sum_expr->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } + + void CollectInternal(const IterMark &mark) { + if (visited_.count(mark)) + return; + visited_.insert(mark); + if (auto *op = mark->source.as()) { + for (IterSplitExpr split : op->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } +}; + +Array get_unused_iters(const IterMark &mark, + const std::vector &splits, + Analyzer *analyzer) { + PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); + std::vector used(splits.size(), false); + std::vector results; + size_t i = 0; + for (; i < splits.size();) { + size_t j = 0; + size_t lowest = splits.size(); + for (; j < splits.size(); ++j) { + if (used[j]) + continue; + if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor, + expected_lower_factor)) { + break; + } + if (lowest == splits.size() || + CanProveDivisible(splits[lowest]->lower_factor, + splits[j]->lower_factor)) { + lowest = j; + } + } + if (j == splits.size()) { + ICHECK(lowest != splits.size()); + ICHECK(CanProveDivisible(splits[lowest]->lower_factor, + expected_lower_factor)) + << " Cannot prove divisible for " << splits[lowest]->lower_factor + << " and " << expected_lower_factor; + results.emplace_back( + mark, expected_lower_factor, + FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1); + expected_lower_factor = splits[lowest]->lower_factor; + } else { + used[j] = true; + i++; + expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; + } + } + bool match_full_iter = + analyzer->CanProveEqual(expected_lower_factor, mark->extent); + if (!match_full_iter) { + results.emplace_back(mark, expected_lower_factor, + FloorDiv(mark->extent, expected_lower_factor), 1); + } + return results; +} + +// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator +// appears in fused forms across multiple index expressions. We first normalize +// every index into IterSumExpr, collect all splits per source Var, then +// consolidate them to avoid misclassifying a used split as unused. +Array DivideUnusedIterators(const Array &exprs, + const Array input_iters, + Analyzer *analyzer) { + auto iter_sum = exprs.Map([&](const auto &e) { + return NormalizeToIterSum(e, ToVMap(input_iters), analyzer); + }); + IterMarkSplitCollector collector; + collector.Collect(iter_sum); + Array results; + + for (const IterMark &mark : collector.visited_) { + if (!mark->source.as()) { + std::ostringstream oss; + oss << "Not a normalized iterator: " << mark; + throw NormalizeIterException(oss.str()); + } + } + + for (const IterVar &iter : input_iters) { + // Merge splits from all IterMark that share the same source Var as `iter`. + std::vector merged_splits; + for (const IterMark &mark : collector.visited_) { + auto vexpr = mark->source.as(); + if (vexpr && vexpr.value().same_as(iter->var)) { + auto it = collector.mark2splits_.find(mark); + if (it != collector.mark2splits_.end()) { + const auto &vec = it->second; + merged_splits.insert(merged_splits.end(), vec.begin(), vec.end()); + } + } + } + + if (!merged_splits.empty()) { + // Use a unified mark (Var + full extent) to compute the missing pieces + // so that fused usages are honored as "used" and not reintroduced. 
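+      // E.g. if x has extent 4 but only x // 2 appears in the indices, the
+      // missing piece x % 2 is emitted here as an extra IterSplitExpr.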
+ IterMark unified_mark(iter->var, iter->dom->extent); + auto splits = get_unused_iters(unified_mark, merged_splits, analyzer); + // Put the small axis last for a flattened ordering. + results.insert(results.end(), splits.rbegin(), splits.rend()); + } else if (!is_one(iter->dom->extent)) { + auto mark = IterMark(iter->var, iter->dom->extent); + auto split = IterSplitExpr(mark, 1, iter->dom->extent, 1); + results.push_back(split); + } + } + return results; +} + +PrimExpr MakeFlattenedExpression(const Array &splits) { + Array lists; + PrimExpr scale = 1; + for (int i = splits.size() - 1; i >= 0; i--) { + auto scaled_split = arith::IterSplitExpr( + splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale); + lists.push_back(scaled_split); + scale *= splits[i]->extent; + } + return arith::NormalizeIterMapToExpr(arith::IterSumExpr(lists, 0)); +} + +class IterSumMutator { +public: + IterSumMutator(const Map &replace_map) + : replace_map_(replace_map) {} + + // override the original mutate function. + IterSumExpr Mutate(const IterSumExpr &iter_sum) { + Array args; + for (const auto &split : iter_sum->args) { + if (replace_map_.count(split)) { + args.push_back(replace_map_[split]); + } else { + auto split_ = IterSplitExpr(Mutate(split->source), split->lower_factor, + split->extent, split->scale); + args.push_back(split_); + } + } + return IterSumExpr(args, iter_sum->base); + } + + IterMark Mutate(const IterMark &mark) { + if (auto *op = mark->source.as()) { + return IterMark(Mutate(tvm::ffi::GetRef(op)), mark->extent); + } else { + return mark; + } + } + +private: + Map replace_map_; +}; + +std::pair CompressIterator(const PrimExpr &expr, + const Array input_iters, + const Var &var, + arith::Analyzer *analyzer) { + auto iter_sum = + arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer); + IterMarkSplitCollector collector; + collector.Collect({iter_sum}); + IterMark mark; + for (const IterMark &m : collector.visited_) { + ICHECK(m->source.as()) << "Not a normalized iterator: " << mark; + if (m->source.as().value().same_as(var)) { + mark = m; + break; + } + } + std::vector splits; + if (mark.defined()) { + splits = collector.mark2splits_[mark]; + } + + PrimExpr extent = 1; + for (const auto &split : splits) { + extent *= split->extent; + } + extent = analyzer->Simplify(extent); + + auto new_var = Var(var->name_hint, var->type_annotation); + auto new_iter_var = IterVar(Range(0, extent), new_var, IterVarType::kDataPar); + auto new_mark = IterMark(new_var, extent); + PrimExpr scale = 1; + Map replace_map; + for (const auto &split : splits) { + auto rescaled = + arith::IterSplitExpr(new_mark, scale, split->extent, split->scale); + replace_map.Set(split, rescaled); + scale *= split->extent; + } + + IterSumMutator mutator(replace_map); + PrimExpr reaplced = + analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum))); + + return {reaplced, new_iter_var}; +} + +Array ToIterVars(const Map &vmap) { + Array result; + for (const auto &[var, range] : vmap) { + result.push_back(IterVar(range, var, IterVarType::kDataPar)); + } + return result; +} + +Map ToVMap(const Array &ivs) { + Map result; + for (const auto &iv : ivs) { + result.Set(iv->var, iv->dom); + } + return result; +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/layout/utils.h b/tilelang/original/src/layout/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0f03a8617f1c08d323f8c2b1975be1dc40985bb5 --- /dev/null +++ b/tilelang/original/src/layout/utils.h @@ -0,0 
+1,72 @@ +/*! + * \file layout/utils.h + * \brief Some arith tools for layout & fragment inference + * + */ + +#ifndef TVM_TL_LAYOUT_UTILS_H_ +#define TVM_TL_LAYOUT_UTILS_H_ + +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class NormalizeIterException : public std::exception { +public: + const char *what() const noexcept override { return msg_.c_str(); } + NormalizeIterException(const std::string &msg) : msg_(msg) {} + +private: + std::string msg_; +}; + +/*! + * \brief Collect the IterSplit that is not used in expr. + * + * If the expr is (x // 2) and x is in Range(4), + * than the result should be (x % 2) + */ +Array +DivideUnusedIterators(const Array &exprs, + const Array input_iters, + arith::Analyzer *analyzer); + +/*! + * \brief Compress the iterator var, remove the unused part of the var not + * present in the expr + * + * Returns the compressed IterVar as well as the Updated iter sum expression. + */ +std::pair CompressIterator(const PrimExpr &expr, + const Array input_iters, + const Var &var, + arith::Analyzer *analyzer); + +/*! + * \brief Convert the iter splits returned by DivideUnusedIterators into + * flattened expression + * + */ +PrimExpr MakeFlattenedExpression(const Array &splits); + +/*! + * \brief Convert an Array of IterVar to a Map object + * + */ +Map ToVMap(const Array &ivs); + +/*! + * \brief Convert a Map object to an Array of IterVar + * + */ +Array ToIterVars(const Map &vmap); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_LAYOUT_UTILS_H_ diff --git a/tilelang/original/src/op/atomic_add.cc b/tilelang/original/src/op/atomic_add.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fa0c6b5313a6f9fea392cda8735bb51a293bc01 --- /dev/null +++ b/tilelang/original/src/op/atomic_add.cc @@ -0,0 +1,550 @@ +/*! + * \file tl/op/atomic_add.cc + * + * Define element-wise operators. + */ + +#include "./atomic_add.h" +#include "utils.h" +#include +#include +#include + +#include "../target/utils.h" +#include "../transform/atomicadd_vectorize.h" +#include "../transform/common/loop_fusion_utils.h" +#include "../transform/common/loop_parallel_transform_utils.h" +#include "../transform/loop_partition.h" +#include "builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct an AtomicAdd operator from call arguments and a buffer map. + * + * Builds the internal AtomicAddNode, extracts the source and destination + * regions and their backing Buffers from the first two region-style expressions + * in `args` (BufferLoad/BufferRegion), and stores them along with their + * ranges. If a third argument is provided, it is interpreted as an integer + * immediate and stored as the node's coalesced width. + * + * @param args Call-style PrimExprs where: + * - args[0] is the source region call, + * - args[1] is the destination region call, + * - args[2] (optional) is an IntImm specifying coalesced width. + * Notes: + * - The constructor checks that args[0] and args[1] are region-compatible. + * - The constructed node is stored in this->data_. 
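+ * (Note: as implemented below, args[2] is the use_tma flag, args[3] the
+ * memory_order and args[4] the coalesced width, each an optional IntImm.)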
+ */ +AtomicAdd::AtomicAdd(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + if (args.size() >= 3) { + node->use_tma = Downcast(args[2]); + } + node->memory_order = IntImm(0); + if (args.size() >= 4) { + node->memory_order = Downcast(args[3]); + } + if (args.size() >= 5) { + node->coalesced_width = Downcast(args[4]); + } + data_ = std::move(node); +} + +/** + * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator. + * + * Produces a new AtomicAddNode object copied from this node. If this node has + * an associated ParallelOp (par_op_), the parallel op is cloned and attached to + * the new node so the cloned operator preserves parallelization state. + * + * @return TileOperator A TileOperator owning the cloned AtomicAddNode. + */ +TileOperator AtomicAddNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } + return AtomicAdd(op); +} + +/** + * @brief Create data-parallel iteration variables for non-singleton dimensions + * of the source. + * + * Constructs an Array of IterVar corresponding to each dimension in `src_range` + * whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a + * Var named sequentially ("i", "j", "k", ...) with the same dtype as the + * extent, and type IterVarType::kDataPar. The ordering of returned itervars + * matches the order of dimensions in `src_range`. + * + * @return Array Iteration variables for all non-singleton extents in + * `src_range`. + */ +Array AtomicAddNode::MakeIterVars() const { + Array loop_vars; + size_t idx = 0; + for (size_t i = 0; i < src_range.size(); i++) { + if (is_one(src_range[i]->extent)) + continue; + Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); + idx++; + loop_vars.push_back( + {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); + } + return loop_vars; +} + +// ivs: itervars returned by MakeIterVars() +/** + * @brief Build index expressions for either source or destination from loop + * iter vars. + * + * Given a list of iteration variables that correspond to the non-singleton + * extents of the selected region (source when src_dst == 0, destination when + * src_dst == 1), return an array of index expressions matching the full rank of + * that region. For dimensions with extent == 1, the corresponding index is the + * range's minimum; otherwise the index is `min + ivar`. + * + * @param ivs Iteration variables in order for all non-singleton dimensions of + * the chosen region. + * @param src_dst Selects which region to index: 0 for source (src_range), 1 for + * destination (dst_range). + * @return Array Index expressions for every dimension of the selected + * region, in original dimension order. + * + * @note The function checks that the number of provided iter vars equals the + * number of non-singleton extents; it will abort (ICHECK) if they differ. + */ +Array AtomicAddNode::MakeIndices(const Array &ivs, + int src_dst) const { + Array indices; + Array ranges = src_dst == 0 ? 
src_range : dst_range; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + indices.push_back(ranges[i]->min); + else { + indices.push_back(ranges[i]->min + ivs[idx]->var); + idx++; + } + } + ICHECK(idx == ivs.size()) + << "idx = " << idx << ", ivs.size() = " << ivs.size() + << "src name = " << src->name << ", dst name = " << dst->name; + return indices; +} + +std::pair, PrimExpr> +AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { + Array indices; + Array ranges = src_dst == 0 ? src_range : dst_range; + PrimExpr size = 1; + for (size_t i = 0; i < ranges.size(); i++) { + indices.push_back(ranges[i]->min); + size *= ranges[i]->extent; + } + return {indices, size}; +} + +/** + * @brief Build a combined bound-check predicate for indexed access. + * + * Constructs an AND'd predicate ensuring each non-singleton index (derived from + * `ivs`) stays within [0, extent) for the selected operand (source when + * `src_dst==0`, destination otherwise). For each non-unit Range in the chosen + * range list this produces two conditions: + * - range.min + iv >= 0 + * - range.min + iv < extent + * + * Conditions that the analyzer can prove (with symbolic bounds) are omitted. + * If no uncertain conditions remain, an empty PrimExpr is returned. + * + * Note: the function ICHECKs that `extents.size()` equals the number of ranges + * for the selected operand. + * + * @param ivs Iteration variables corresponding to non-singleton extents (order + * matches the non-unit ranges of the chosen operand). + * @param extents Per-dimension upper bounds to check against; must have the + * same size as the selected range list. + * @param src_dst Selects which ranges to validate: 0 => `src_range`, else + * `dst_range`. + * @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or + * an empty PrimExpr when no checks are required. + */ +PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, + int src_dst) const { + Array ranges = src_dst == 0 ? src_range : dst_range; + Array cond_list; + ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + continue; + PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i]; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + cond = ranges[i]->min + ivs[idx]->var >= 0; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + idx++; + } + if (cond_list.empty()) + return {}; + else { + PrimExpr cond = cond_list[0]; + for (size_t i = 1; i < cond_list.size(); i++) + cond = And(cond, cond_list[i]); + return cond; + } +} + +/** + * @brief Build a SIMT-style loop nest that performs element-wise atomic + * additions from src to dst. + * + * Constructs a nested loop (parallelized per iter var) that loads a value from + * the source buffer, optionally casts it to the destination dtype, and performs + * an extern atomic add into the destination buffer address. For scalar + * (zero-dimensional) operations a trivial serial For with a single BufferStore + * is returned. + * + * The method: + * - Creates iter vars for all non-singleton extents and binds them into the + * provided analyzer. + * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch). 
+ * - Computes indexed accesses and emits optional bound predicates; + * out-of-bounds accesses are masked to zero when predicates are uncertain. + * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value), + * src_value)` call wrapped in an Evaluate statement. + * - Wraps the body with a parallel For at each loop level. If `coalesced_width` + * is defined it is attached as the "coalesced_width" annotation on each loop. + * + * Note: This function mutates the analyzer binding state by binding loop + * variables and may fail via ICHECK if internal assumptions about shapes are + * violated. + * + * @return A nested For loop (parallel loops) implementing the atomic-add + * kernel. For scalar cases a serial For of extent 1 is returned. + */ +For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { + Array loop_vars = MakeIterVars(); + bool is_scalar = loop_vars.empty(); + if (is_scalar) { + return For(Var("i"), 0, 1, ForKind::kSerial, + BufferStore(dst, BufferLoad(src, {0}), {0})); + } + + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + + ICHECK(loop_vars.size() <= src_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", src_range.size() = " << src_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + ICHECK(loop_vars.size() <= dst_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + Array src_indices = MakeIndices(loop_vars, 0); + Array dst_indices = MakeIndices(loop_vars, 1); + + Array new_args; + + // Optional bounds predicates for src and dst + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); + + // Load source value and cast to dst dtype if needed + PrimExpr src_value = BufferLoad(src, src_indices); + if (src->dtype != dst->dtype) + src_value = Cast(dst->dtype, src_value); + + // Build a pointer to destination element using tvm_access_ptr + PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(dst, dst_indices)}); + + new_args.push_back(dst_ptr); + new_args.push_back(src_value); + new_args.push_back(memory_order); + + Call atomicadd_call = + tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args); + + Stmt body = tvm::tir::Evaluate(atomicadd_call); + + for (int i = loop_vars.size() - 1; i >= 0; i--) { + Map annotations = {}; + if (coalesced_width.defined()) { + annotations.Set("coalesced_width", coalesced_width); + } + + body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, + ForKind::kParallel, body, std::nullopt, annotations); + } + return Downcast(body); +} + +/** + * @brief Infer and return the layout map for the atomic add operator. + * + * Constructs a cached ParallelOp (by building the SIMT loop) if not already + * present, validates that local.fragment layouts for src and dst match when + * both are provided, and then delegates layout inference to the underlying + * ParallelOp. + * + * @param T Layout inference inputs, including an optional mapping of buffers to + * layouts. + * @param level Inference strictness level. + * @return LayoutMap The inferred layout mapping for buffers used by this + * operator. + * + * @note This method mutates the AtomicAddNode by creating and storing a + * ParallelOp on first invocation. 
+ * @throws If both src and dst have layouts in `local.fragment` and their + * fragment layouts differ, an ICHECK failure is raised with diagnostic output. + */ +LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (T.layout_map.count(src) && T.layout_map.count(dst)) { + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { + const FragmentNode *src_layout = T.layout_map[src].as(); + const FragmentNode *dst_layout = T.layout_map[dst].as(); + if (src_layout && dst_layout) { + ICHECK(src_layout->IsEqual(dst_layout, true)) + << "Get different layout for " << src << " and " << dst + << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the layout"; + } + } + } + return {}; +} + +/** + * @brief Lower the atomic-add top-level operator into a parallel, vectorized + * TIR loop. + * + * Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs + * layout inference at multiple levels, partitions the root loop by the provided + * thread variable, vectorizes the thread loop, and returns the final + * (optionally predicate-guarded) statement. + * + * The lowering pipeline: + * - Build the SIMT loop via MakeSIMTLoop. + * - Fuse parallel loops into a single For and wrap as a ParallelOp. + * - Run layout inference at kCommon, kStrict, and kFree levels using fields + * from `T`. + * - Obtain the loop layout, partition the root loop with PartitionLoop by + * `T.thread_var`. + * - Vectorize the partitioned thread loop via VectorizeLoop. + * - If the ParallelOp produced a predicate for `T.thread_var`, return an + * IfThenElse that guards the vectorized loop with that predicate; otherwise + * return the vectorized loop. + * + * @param T Lowering context whose fields are used: + * - T.target: target architecture for layout inference and lowering + * decisions. + * - T.thread_var: the Var used to partition the outer loop for thread-level + * parallelism. + * - T.thread_bounds: bounds associated with the thread dimension (used during + * partitioning). + * - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used + * during InferLayout. + * @param analyzer Analyzer used for symbolic reasoning during partitioning and + * folding (omitted from detailed param docs as a common analysis utility). + * @return Stmt A lowered TIR statement representing the parallelized and + * vectorized atomic-add. 
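+ * (When use_tma is non-zero, the implementation below bypasses this loop
+ * pipeline and instead emits a single tl.tma_store call, guarded so that
+ * only the thread at thread_bounds->min issues it.)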
+ */ +Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + if (use_tma->value != 0) { + Array src_indices, dst_indices; + PrimExpr src_size, dst_size; + std::tie(src_indices, src_size) = ReturnIndicesAndSize(0); + std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1); + ICHECK(analyzer->CanProveEqual(src_size, dst_size)) + << "src_size = " << src_size << ", dst_size = " << dst_size; + BufferLoad src_node = BufferLoad(src, src_indices); + BufferLoad dst_node = BufferLoad(dst, dst_indices); + Call address_of_src = + Call(DataType::Handle(), builtin::address_of(), {src_node}); + Call address_of_dst = + Call(DataType::Handle(), builtin::address_of(), {dst_node}); + + int need_reduce = 1; + int eviction_policy = 0; + auto body = Evaluate(Call(DataType::Handle(), tma_store(), + {address_of_src, address_of_dst, + ceildiv(src_size * src->dtype.bits(), 8), + need_reduce, eviction_policy})); + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body); + } + auto simt_loop = MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + auto transformed_loop = + Downcast(ParallelLoopTransformer::Substitute(fused_loop)); + + auto GetArchInt = [&](const Target &tgt) -> int { + int arch_int = 0; + if (auto s = tgt->GetAttr("arch")) { + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) + arch_int = std::stoi(arch.substr(3)); + } + return arch_int; + }; + + struct AtomicLoopNestCollector : tir::StmtExprVisitor { + Array loop_vars; + Map> indice_map; + std::unordered_set writes; + arith::Analyzer analyzer; + + void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); } + + void VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kParallel) { + loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var, + IterVarType::kDataPar)); + } + analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.scope() == "local.fragment") { + indice_map.Set(op->buffer, op->indices); + writes.insert(op->buffer); + } + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode *op) final { + if (op->buffer.scope() == "local.fragment") { + indice_map.Set(op->buffer, op->indices); + } + StmtExprVisitor::VisitExpr_(op); + } + }; + + auto ComputeLoopLayoutFromBuffer = + [&](const Buffer &buf, const Array &indices, + const LayoutMap &layout_map, const Range &thread_bounds, + const Array &loop_vars) -> Fragment { + Fragment src = layout_map[buf].as().value(); + Var rep; + auto rep_iter = + IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar); + PrimExpr fth = src->ForwardThread(indices, rep); + fth = analyzer->Simplify(fth); + Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter) + ->BindThreadRange(thread_bounds); + return out; + }; + + struct AtomicInferResult { + Fragment loop_layout; + Optional predicate; + }; + + auto AtomicAddInferLayout = + [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult { + AtomicLoopNestCollector C; + C.Run(loop); + Optional read_src; + int best_rank = -1; + for (auto kv : C.indice_map) { + const Buffer &buf = kv.first; + if (buf.scope() != "local.fragment") + continue; + if (!args.layout_map.count(buf)) + continue; + int rank = static_cast(kv.second.size()); + if (rank > best_rank) { + best_rank = rank; + read_src = buf; + } + } + AtomicAddVectorizePlanner planner; + int sm = 
GetArchInt(target); + auto plan = planner.Plan(loop, sm); + int vec = std::max(plan.vector_size, 1); + if (auto cw = loop->annotations.Get("coalesced_width")) { + if (const auto *imm = cw->as()) { + int expected = imm->value; + ICHECK_GT(expected, 0); + ICHECK(vec % expected == 0) + << "vector_size " << vec << " not divisible by coalesced_width " + << expected; + vec = expected; + } else { + LOG(FATAL) << "coalesced_width should be IntImmNode."; + } + } + PrimExpr total = 1; + for (Stmt s = loop; s.as().has_value(); s = s.as().value()->body) + total = total * s.as().value()->extent; + PrimExpr denom = args.thread_bounds->extent * vec; + while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) { + vec >>= 1; + denom = args.thread_bounds->extent * vec; + } + if (vec < 1) + vec = 1; + Fragment loop_layout; + if (read_src) { + loop_layout = ComputeLoopLayoutFromBuffer( + read_src.value(), C.indice_map[read_src.value()], args.layout_map, + args.thread_bounds, C.loop_vars); + } else { + const For &remapped = loop; + loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds); + } + + Optional pred; + if (plan.dynamic && plan.condition.defined()) { + pred = plan.condition; + } + DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec + << " loop_layout=" << loop_layout->DebugOutput(); + return {loop_layout, pred}; + }; + + auto ret = AtomicAddInferLayout(transformed_loop, + {T.target, T.thread_bounds, T.layout_map, + analyzer, false, T.buffer_remap}); + Fragment loop_layout = ret.loop_layout; + auto thread_loop = + PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout); + auto vectorized_thread_loop = + VectorizeAtomicAdd(thread_loop, GetArchInt(target)); + return vectorized_thread_loop; +} + +TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); } + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/atomic_add.h b/tilelang/original/src/op/atomic_add.h new file mode 100644 index 0000000000000000000000000000000000000000..c6beb70eb03a37eac3a5a8505808bf5ede736f6a --- /dev/null +++ b/tilelang/original/src/op/atomic_add.h @@ -0,0 +1,75 @@ +/*! 
+ * \file tl/op/atomic_add.h + * \brief Atomic addition operations for concurrent memory updates + */ + +#ifndef TVM_TL_OP_ATOMIC_ADD_H_ +#define TVM_TL_OP_ATOMIC_ADD_H_ + +#include "operator.h" +#include "parallel.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/// Node class for atomic addition operations +class AtomicAddNode : public TileOperatorNode { +public: + Buffer src, dst; ///< Source and destination buffers + Array src_range, + dst_range; ///< Access ranges for source and destination + IntImm use_tma; ///< Whether to use TMA for memory operations + IntImm coalesced_width; ///< Width for memory coalescing optimization + IntImm memory_order; ///< Memory order for atomic operations + + mutable ParallelOp par_op_; ///< Associated parallel operation + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode, + TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + + static const Op &Get(); + TileOperator Clone() const; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &AtomicAddNode::src) + .def_ro("dst", &AtomicAddNode::dst) + .def_ro("src_range", &AtomicAddNode::src_range) + .def_ro("dst_range", &AtomicAddNode::dst_range) + .def_ro("use_tma", &AtomicAddNode::use_tma) + .def_ro("coalesced_width", &AtomicAddNode::coalesced_width) + .def_ro("memory_order", &AtomicAddNode::memory_order); + } + +protected: + /// Create SIMT-style parallel loop structure + For MakeSIMTLoop(arith::Analyzer *analyzer) const; + /// Generate iteration variables for loop nest + Array MakeIterVars() const; + /// Generate buffer indices from iteration variables + Array MakeIndices(const Array &ivs, int src_dst) const; + /// Return buffer indices and size + std::pair, PrimExpr> ReturnIndicesAndSize(int src_dst) const; + /// Create boundary predicate for memory safety + PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, + Array extents, int src_dst) const; +}; + +/// Wrapper class for atomic addition operations +class AtomicAdd : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, + AtomicAddNode); + TVM_DLL AtomicAdd(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_ATOMIC_ADD_H_ diff --git a/tilelang/original/src/op/builtin.cc b/tilelang/original/src/op/builtin.cc new file mode 100644 index 0000000000000000000000000000000000000000..f68bee0b7975b767e89ce0d224c64fe54a53a989 --- /dev/null +++ b/tilelang/original/src/op/builtin.cc @@ -0,0 +1,385 @@ +/*! + * \file tl/op/builtin.cc + * \brief Builtin intrinsics. 
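+ *
+ * Registers the tl.* pass-config options and defines the tl.* intrinsic
+ * ops via TIR_DEFINE_TL_BUILTIN, which both creates an Op accessor
+ * function and registers the op with TVM.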
+ * + */ + +#include "builtin.h" + +#include +#include +#include + +#include "../target/cuda.h" +#include "../target/utils.h" + +namespace tvm { +namespace tl { + +TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String); +TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array); + +DataType cuTensorMapType() { return DataType::UInt(8, 128); } + +#define TIR_DEFINE_TL_BUILTIN(OpName) \ + const Op &OpName() { \ + static const Op &op = Op::Get("tl." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tl." #OpName) \ + .set_attr("TScriptPrinterName", #OpName) + +// fast math related op +TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__exp10).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__log).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__log2).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__log10).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__tan).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +// high precision with IEEE-compliant +TIR_DEFINE_TL_BUILTIN(ieee_add).set_num_inputs(3).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_sub).set_num_inputs(3).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_mul).set_num_inputs(3).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_fmaf).set_num_inputs(4).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_frcp).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_fsqrt) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ieee_frsqrt) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + 
+TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(create_tma_im2col_descriptor) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_mbarrier) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(tma_load).set_num_inputs(-1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_load_im2col) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(mbarrier_wait_parity) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss) + .set_num_inputs(15) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs) + .set_num_inputs(15) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss) + .set_num_inputs(14) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts) + .set_num_inputs(13) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70) + .set_num_inputs(13) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_stmatrix) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(fence_proxy_async) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store_arrive) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store_wait) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(set_max_nreg) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + 
+TIR_DEFINE_TL_BUILTIN(no_set_max_nreg) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warpgroup_arrive) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warpgroup_commit_batch) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warpgroup_wait) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(get_lane_idx) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_warp_idx) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_warp_group_idx) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(wait_wgmma) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(loop_break) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tl_gemm).set_num_inputs(4).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tl_gemm_sp) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tvm_mfma).set_num_inputs(12).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tvm_mmac).set_num_inputs(12).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tvm_mfma_store) + .set_num_inputs(6) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma) + .set_num_inputs(12) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma_store) + .set_num_inputs(6) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor) + .set_num_inputs(7) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(device_assert) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(device_assert_with_msg) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + 
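+// Editorial note: hedged sketch, not part of the original source. The
+// "TCallEffectKind" attribute attached to each op above can be queried
+// through TVM's generic op attribute maps, for example to decide whether a
+// call may be hoisted or eliminated:
+//
+//   const Op &op = Op::Get("tl.pack_b16");
+//   static auto effect_map = Op::GetAttrMap<Integer>("TCallEffectKind");
+//   bool is_pure = effect_map.count(op) &&
+//                  effect_map[op]->value ==
+//                      static_cast<int>(CallEffectKind::kPure);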
+TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_sum) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_max) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_min) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +// __ldg(BufferLoad | Buffer, idx?) -> value +// Treat as a pure call that returns the loaded value. +TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/builtin.h b/tilelang/original/src/op/builtin.h new file mode 100644 index 0000000000000000000000000000000000000000..9647789202c3c1d6a93951ecec2d72a6c4c4d729 --- /dev/null +++ b/tilelang/original/src/op/builtin.h @@ -0,0 +1,624 @@ +/*! + * \file tl/op/builtin.h + * \brief Builtin intrinsics. + * + */ + +#ifndef TVM_TL_OP_BUILTIN_H_ +#define TVM_TL_OP_BUILTIN_H_ + +#include "operator.h" +#include + +namespace tvm { +/*! + * \brief Create the TVM intrinsic that initializes a PTX fence barrier. + * + * Initializes a PTX fence-style barrier used to coordinate asynchronous memory + * operations (for example, TMA/TMA_STORE). Returns the Op representing this + * intrinsic for use in TIR lowering and code generation. + * + */ +namespace tl { + +namespace attr { +static constexpr const char *kSafeValueMap = "safe_value_map"; +static constexpr const char *kWarpSpecializationScope = + "kWarpSpecializationScope"; +static constexpr const char *kCustomWarpSpecialization = + "kCustomWarpSpecialization"; +static constexpr const char *kLocalVarInit = "tl.local_var_init"; +// A PrimFunc-level attribute carrying a list of handle Vars +// that must NOT be marked with the restrict qualifier in codegen. 
+// Type: Array<Var>
+static constexpr const char *kNonRestrictParams = "tl.non_restrict_params";
+} // namespace attr
+
+static constexpr const char *kDebugMergeSharedMemoryAllocations =
+    "tl.debug_merge_shared_memory_allocations";
+static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
+static constexpr const char *kDisableSafeMemoryLegalize =
+    "tl.disable_safe_memory_legalize";
+static constexpr const char *kDisableWarpSpecialized =
+    "tl.disable_warp_specialized";
+static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
+static constexpr const char *kEnableAggressiveSharedMemoryMerge =
+    "tl.enable_aggressive_shared_memory_merge";
+static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
+static constexpr const char *kEnableFastMath = "tl.enable_fast_math";
+static constexpr const char *kPtxasRegisterUsageLevel =
+    "tl.ptxas_register_usage_level";
+static constexpr const char *kEnablePTXASVerboseOutput =
+    "tl.enable_ptxas_verbose_output";
+static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256";
+static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
+static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
+static constexpr const char *kStorageRewriteDetectInplace =
+    "tl.storage_rewrite_detect_inplace";
+static constexpr const char *kLayoutVisualizationEnable =
+    "tl.layout_visualization_enable";
+static constexpr const char *kLayoutVisualizationFormats =
+    "tl.layout_visualization_formats";
+static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags";
+
+/*!
+ * \brief Whether to disable thread storage synchronization
+ *
+ * When enabled, disables the automatic insertion of thread synchronization
+ * barriers (e.g., __syncthreads()) for shared memory access coordination.
+ * This can be useful for performance optimization in cases where manual
+ * synchronization is preferred or when synchronization is not needed.
+ *
+ * kDisableThreadStorageSync = "tl.disable_thread_storage_sync"
+ *
+ */
+static constexpr const char *kDisableThreadStorageSync =
+    "tl.disable_thread_storage_sync";
+
+/*!
+ * \brief Force inline Let bindings during simplification.
+ *
+ * kForceLetInline = "tl.force_let_inline"
+ *
+ */
+static constexpr const char *kForceLetInline = "tl.force_let_inline";
+
+/*!
+ * \brief Get the type of the CUDA tensor map
+ *
+ * DataType cuTensorMapType()
+ *
+ */
+DataType cuTensorMapType();
+
+// Fast-math related ops
+// __exp(x) - fast exponential
+TVM_DLL const Op &__exp();
+// __exp10(x) - fast base-10 exponential
+TVM_DLL const Op &__exp10();
+// __log(x) - fast natural logarithm
+TVM_DLL const Op &__log();
+// __log2(x) - fast base-2 logarithm
+TVM_DLL const Op &__log2();
+// __log10(x) - fast base-10 logarithm
+TVM_DLL const Op &__log10();
+// __tan(x) - fast tangent
+TVM_DLL const Op &__tan();
+// __cos(x) - fast cosine
+TVM_DLL const Op &__cos();
+// __sin(x) - fast sine
+TVM_DLL const Op &__sin();
+
+// High-precision ops with IEEE-compliant rounding.
+// ieee_add(x, y, rounding_mode) - IEEE-compliant addition +TVM_DLL const Op &ieee_add(); +// ieee_sub(x, y, rounding_mode) - IEEE-compliant subtraction +TVM_DLL const Op &ieee_sub(); +// ieee_mul(x, y, rounding_mode) - IEEE-compliant multiplication +TVM_DLL const Op &ieee_mul(); +// ieee_fmaf(x, y, z, rounding_mode) - IEEE-compliant fused multiply-add +TVM_DLL const Op &ieee_fmaf(); +// ieee_frcp(x, rounding_mode) - IEEE-compliant reciprocal +TVM_DLL const Op &ieee_frcp(); +// ieee_fsqrt(x, rounding_mode) - IEEE-compliant square root +TVM_DLL const Op &ieee_fsqrt(); +// ieee_frsqrt(x) - IEEE-compliant reciprocal square root (rn only) +TVM_DLL const Op &ieee_frsqrt(); +// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division +TVM_DLL const Op &ieee_fdiv(); + +// random op +TVM_DLL const Op &rng_init(); +TVM_DLL const Op &rng_rand(); + +/*! + * \brief tvm intrinsics for TMADescriptor creation for tiled load + * + * CuTensorMap* create_tma_descriptor(data_type, rank, global_addr, + * global_shape..., global_stride..., smem_box..., smem_stride..., interleave, + * swizzle, l2_promotion, oob_fill) + * + */ +TVM_DLL const Op &create_tma_descriptor(); + +/*! + * \brief tvm intrinsics for TMADescriptor creation for image to column load + * + * CuTensorMap* create_tma_im2col_descriptor(data_type, rank, global_addr, + * global_shape..., global_stride..., elem_stride..., lower_corner..., + * upper_corner..., smme_box_pixel, smem_box_channel, interleave, swizzle, + * l2_promotion, oob_fill) + * + */ +TVM_DLL const Op &create_tma_im2col_descriptor(); + +/*! + * \brief Create a list of mbarrier with num_threads + * + * create_list_of_mbarrier(num_threads0, num_threads1, ...) + * + */ +TVM_DLL const Op &create_list_of_mbarrier(); + +/*! + * \brief Get the mbarrier with barrier_id + * + * int64_t* GetMBarrier(barrier_id) + * + */ +TVM_DLL const Op &get_mbarrier(); + +/*! + * \brief tvm intrinsics for loading data from global tensor descriptor to + * shared memory + * + * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...) + * + */ +TVM_DLL const Op &tma_load(); + +/*! + * \brief tvm intrinsics for loading image from global tensor to columns in + * shared memory + * + * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ..., + * image_offset, ...) + * + */ +TVM_DLL const Op &tma_load_im2col(); + +/*! + * \brief tvm intrinsics for storing data from shared memory to global tensor + * descriptor + * + * tma_store(descriptor, smem_data, coord_0, coord_1, ...) + * + */ +TVM_DLL const Op &tma_store(); + +/*! + * \brief tvm intrinsics for barrier initialization fence + * + * ptx_fence_barrier_init() + * + */ +const Op &ptx_fence_barrier_init(); + +/*! + * \brief tvm intrinsics for mbarrier wait with parity bit + * + * mbarrier_wait_parity(mbarrier, parity) + * + */ +TVM_DLL const Op &mbarrier_wait_parity(); + +/*! + * \brief tvm intrinsics for mbarrier expect tx + * + * mbarrier_expect_tx(mbarrier, transaction_bytes) + * + */ +TVM_DLL const Op &mbarrier_expect_tx(); + +/*! + * \brief tvm intrinsic for ptx tensor core wgmma instructions. + * + * void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool + * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm + * b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr + * A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool + * scale_out, bool scale_in_a, bool scale_in_b); + */ +TVM_DLL const Op &ptx_wgmma_ss(); + +/*! 
+ * \brief tvm intrinsics for ptx tensor core wgmma instructions. + * + * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, + * bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv, + * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var + * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, + * bool scale_in_a, bool scale_in_b); + */ +TVM_DLL const Op &ptx_wgmma_rs(); + +/*! + * \brief tvm intrinsic for tcgen05 mma shared-shared instructions. + */ +TVM_DLL const Op &ptx_tcgen05_mma_ss(); + +/*! + * \brief tvm intrinsic for tcgen05 mma tensor-shared instructions. + */ +TVM_DLL const Op &ptx_tcgen05_mma_ts(); + +/*! + * \brief tvm intrinsics for initializing tensor memory + * + * ptx_init_tensor_memory(tmem_buffer, num_cols) + * + */ +TVM_DLL const Op &ptx_init_tensor_memory(); + +/*! + * \brief tvm intrinsics for deallocating tensor memory + * + * tmem_deallocate(tmem_buffer) + * + */ +TVM_DLL const Op &ptx_deallocate_tensor_memory(); + +/*! + * \brief tvm intrinsic for ptx tensor core mma instructions on SM70. + * + * void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout, + * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, + * Var multiplicand_a, Expr a_index, + * Var multiplicand_b, Expr b_index, + * Var accumulator, Expr c_index, bool saturate); + */ +TVM_DLL const Op &ptx_mma_sm70(); + +/*! + * \brief tvm intrinsics for ldmatrix + * + * ptx_ldmatrix(transposed, num, shared_addr, local_addr) + * + */ +TVM_DLL const Op &ptx_ldmatrix(); + +/*! + * \brief tvm intrinsics for stmatrix + * + * ptx_ldmatrix(transposed, num, shared_addr, int32_values...) + * + */ +TVM_DLL const Op &ptx_stmatrix(); + +/*! + * \brief tvm intrinsic for ptx async copy barrier using + * cp.async.mbarrier.arrive.noinc + * + * This op is used to represent a ptx async copy barrier operation in tilelang. + */ +TVM_DLL const Op &ptx_cp_async_barrier_noinc(); + +/*! + * \brief Pack two b16 value into a b32 value + * + * int32 pack_b16(b16_value, b16_value) + * + */ +TVM_DLL const Op &pack_b16(); + +/*! + * \brief Issue a shared memory fence for async operations + * + * FenceProxyAsync() + * + */ +TVM_DLL const Op &fence_proxy_async(); + +/*! + * \brief Indicate arrival of warp issuing TMA_STORE + * + * tma_store_arrive() + * + */ +TVM_DLL const Op &tma_store_arrive(); + +/*! + * \brief Wait for TMA_STORE to finish + * + * tma_store_wait() + * + */ +TVM_DLL const Op &tma_store_wait(); + +/*! + * \brief Set reg hint for warp-specialized branched + * + * SetMaxNRegInc(num_reg, is_inc) + * + */ +TVM_DLL const Op &set_max_nreg(); + +/*! + * \brief No set reg hint for warp-specialized branched + * + * no_set_max_nreg() + * + */ +TVM_DLL const Op &no_set_max_nreg(); + +/*! + * \brief Arrive at a warpgroup fence for WGMMA sequences + * + * warpgroup_arrive() + * + */ +TVM_DLL const Op &warpgroup_arrive(); + +/*! + * \brief Commit the current warpgroup batch for WGMMA sequences + * + * warpgroup_commit_batch() + * + */ +TVM_DLL const Op &warpgroup_commit_batch(); + +/*! + * \brief Wait for the warpgroup batch identified by num_mma + * + * warpgroup_wait(num_mma) + * + */ +TVM_DLL const Op &warpgroup_wait(); + +/*! + * \brief Fence accumulator operand registers for upcoming WGMMA operations + * + * warpgroup_fence_operand(dtype, ptr, offset, num_regs) + * + */ +TVM_DLL const Op &warpgroup_fence_operand(); + +/*! + * \brief Return the canonical lane index for the calling thread. 
+ *
+ * get_lane_idx([warp_size])
+ *
+ */
+TVM_DLL const Op &get_lane_idx();
+
+/*!
+ * \brief Return the canonical warp index, assuming converged threads.
+ *
+ * get_warp_idx_sync([warp_size])
+ *
+ */
+TVM_DLL const Op &get_warp_idx_sync();
+
+/*!
+ * \brief Return the canonical warp index without synchronizing the warp.
+ *
+ * get_warp_idx([warp_size])
+ *
+ */
+TVM_DLL const Op &get_warp_idx();
+
+/*!
+ * \brief Return the canonical warp group index for converged threads.
+ *
+ * get_warp_group_idx([warp_size, warps_per_group])
+ *
+ */
+TVM_DLL const Op &get_warp_group_idx();
+
+/*!
+ * \brief Wait for the previous wgmma to finish
+ *
+ * wait_wgmma(num_mma)
+ *
+ */
+TVM_DLL const Op &wait_wgmma();
+
+/*!
+ * \brief Synchronize all threads in a grid
+ *
+ * sync_grid()
+ *
+ */
+TVM_DLL const Op &sync_grid();
+
+/*!
+ * \brief tvm intrinsic for loop break
+ *
+ * loop_break()
+ *
+ */
+TVM_DLL const Op &loop_break();
+
+/*!
+ * \brief tvm intrinsic for amd matrix core mfma instructions.
+ *
+ * void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout,
+ * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ * Var multiplicand_a, Expr a_index,
+ * Var multiplicand_b, Expr b_index,
+ * Var accumulator, Expr c_index);
+ */
+TVM_DLL const Op &tvm_mfma();
+
+/*!
+ * \brief tvm intrinsic for amd matrix core mmac instructions.
+ *
+ * void tvm_mmac(StringImm shape, StringImm A_layout, StringImm B_layout,
+ * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ * Var multiplicand_a, Expr a_index,
+ * Var multiplicand_b, Expr b_index,
+ * Var accumulator, Expr c_index);
+ */
+TVM_DLL const Op &tvm_mmac();
+
+/*!
+ * \brief tvm intrinsic for storing the result of AMD MFMA into a destination
+ * pointer.
+ *
+ * There is no real instruction that does that, but we want to hide
+ * details of complex index manipulation behind this intrinsic to simplify TIR
+ * lowering passes (e.g. LowerWarpMemory), as the cuda ptx backend does.
+ *
+ * void tvm_mfma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
+ * src_offset, Var dst_stride);
+ */
+TVM_DLL const Op &tvm_mfma_store();
+
+/*!
+ * \brief tvm intrinsic for amd rdna matrix core instructions.
+ *
+ * void tvm_rdna_wmma(StringImm shape, StringImm A_layout, StringImm B_layout,
+ * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ * Var multiplicand_a, Expr a_index,
+ * Var multiplicand_b, Expr b_index,
+ * Var accumulator, Expr c_index);
+ */
+TVM_DLL const Op &tvm_rdna_wmma();
+
+/*!
+ * \brief tvm intrinsic for storing the result of AMD RDNA WMMA into a
+ * destination pointer.
+ *
+ * There is no real instruction that does that, but we want to hide
+ * details of complex index manipulation behind this intrinsic to simplify TIR
+ * lowering passes (e.g. LowerWarpMemory), as the cuda ptx backend does.
+ *
+ * void tvm_rdna_wmma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
+ * src_offset, Var dst_stride);
+ */
+TVM_DLL const Op &tvm_rdna_wmma_store();
+
+/*!
+ * \brief tilelang intrinsic for general matrix multiplication (GEMM).
+ *
+ * This op is used to represent a generic GEMM operation in tilelang.
+ */
+TVM_DLL const Op &tl_gemm();
+
+/*!
+ * \brief tilelang intrinsic for sparse matrix multiplication (GEMM with
+ * sparsity).
+ *
+ * This op is used to represent a sparse GEMM operation in tilelang.
+ */
+TVM_DLL const Op &tl_gemm_sp();
+
+/*!
+ * \brief tilelang intrinsic for shuffle elect.
+ *
+ * This op is used to represent a shuffle elect operation in tilelang.
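+ *
+ * Editorial note (hedged, not from the original source): a common use of
+ * shuffle elect is to pick a single leader lane so that once-per-warp(-group)
+ * work, such as issuing a TMA copy or updating an mbarrier, runs exactly
+ * once, e.g. schematically:
+ *
+ *   if (tl_shuffle_elect(extent)) { issue_tma_copy(); }
+ *
+ * The precise election semantics are defined by this op's lowering and can be
+ * turned off via the "tl.disable_shuffle_elect" pass config.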
+ */ +TVM_DLL const Op &tl_shuffle_elect(); + +/*! + * \brief tilelang intrinsic for initializing a descriptor buffer for + * wgmma/utcmma. + * + * This op is used to represent a descriptor initialization operation in + * tilelang. + */ +TVM_DLL const Op &initialize_wgmma_descriptor(); + +/*! + * \brief tilelang intrinsic for initializing a descriptor buffer for + * tcgen05 mma. + */ +TVM_DLL const Op &initialize_tcgen05_descriptor(); + +/*! + * \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive. + * + * This op wraps the device-side arrive used to signal completion of MMA work + * to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive. + */ +TVM_DLL const Op &tcgen05_mma_arrive(); + +/*! + * \brief tilelang intrinsic for setting the start address of a descriptor + * buffer for wgmma/utcmma. + * + * This op is used to represent a descriptor start address setting operation in + * tilelang. + */ + +TVM_DLL const Op &increase_descriptor_offset(); + +/*! + * \brief tilelang intrinsic for element-wise atomic addition. + * + * This op is used to represent an element-wise atomic add operation in + * tilelang. + */ +TVM_DLL const Op &atomicadd_elem_op(); + +/*! + * \brief tilelang intrinsic for assert on device. + * + * This op is used to represent an assert on device + */ +TVM_DLL const Op &device_assert(); + +/*! + * \brief tilelang intrinsic for assert on device with additional message. + * + * This op is used to represent an assert on device with additional message. + */ +TVM_DLL const Op &device_assert_with_msg(); + +/*! + * \brief tilelang intrinsic for warp reduction sum. + */ +TVM_DLL const Op &warp_reduce_sum(); + +/*! + * \brief tilelang intrinsic for warp reduction max. + */ +TVM_DLL const Op &warp_reduce_max(); + +/*! + * \brief tilelang intrinsic for warp reduction min. + */ +TVM_DLL const Op &warp_reduce_min(); + +/*! + * \brief tilelang intrinsic for warp reduction bitand. + */ +TVM_DLL const Op &warp_reduce_bitand(); + +/*! + * \brief tilelang intrinsic for warp reduction bitor. + */ +TVM_DLL const Op &warp_reduce_bitor(); + +/*! + * \brief tilelang intrinsic for CUDA read-only cache load (__ldg). + * + * This op allows users to explicitly request a non-coherent cached load + * from global memory on CUDA by emitting `__ldg(&ptr[idx])` for 32-bit + * element types on supported architectures. It provides a direct way to + * leverage the read-only data cache for performance-sensitive loads when + * the compiler cannot infer `const __restrict__` automatically. + * + * Usage from TVMScript: + * y[i] = T.__ldg(x[i]) + * + * The op takes one argument preferred as a BufferLoad identifying the + * source element; alternatively, backends may support passing a Buffer and + * index expression. + */ +TVM_DLL const Op &__ldg(); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_BUILTIN_H_ diff --git a/tilelang/original/src/op/copy.cc b/tilelang/original/src/op/copy.cc new file mode 100644 index 0000000000000000000000000000000000000000..066a09b105a8292a685bff47261b73a4196a2052 --- /dev/null +++ b/tilelang/original/src/op/copy.cc @@ -0,0 +1,2128 @@ +/*! + * \file tl/op/copy.cc + * \brief Define copy operator for various memory transfer strategies (Normal, + * Bulk/TMA, LDSM/STSM) and lowering logic for GPU code generation. 
+ * + * This module is part of TVM TensorIR's Tensor Layout (TL) operations, + * implementing memory copy operations that can target CPUs or GPUs with + * optimization for different instructions like bulk copy, matrix load/store, + * and Hopper's new TMA (Tensor Memory Accelerator). + */ + +#include "copy.h" +#include "../layout/tcgen05_layout.h" +#include "../target/utils.h" +#include "../transform/common/loop_fusion_utils.h" +#include "../transform/common/loop_parallel_transform_utils.h" +#include "../transform/loop_partition.h" +#include "../transform/loop_vectorize.h" +#include "utils.h" + +#include "../target/cuda.h" +#include "../target/utils.h" +#include "builtin.h" +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Helper to map TVM's DataType to CUDA's CUtensorMapDataType enum value. + * This function converts TVM data types to CUDA tensor map data types for TMA + * operations. + */ +static int to_CUtensorMapDataType(DataType dtype) { + CUtensorMapDataType tp; + if (dtype.is_float()) { + switch (dtype.bits()) { + case 64: + tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64; + break; + case 32: + tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + break; + case 16: + tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + break; + case 8: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + default: + ICHECK(0) << dtype; + } + } else if (dtype.is_bfloat16()) { + tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (dtype.is_float8()) { + tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if (dtype.is_int()) { + switch (dtype.bits()) { + case 64: + tp = CU_TENSOR_MAP_DATA_TYPE_INT64; + break; + case 32: + tp = CU_TENSOR_MAP_DATA_TYPE_INT32; + break; + case 16: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 8: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + default: + ICHECK(0) << dtype; + } + } else if (dtype.is_uint()) { + switch (dtype.bits()) { + case 64: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT64; + break; + case 32: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + case 16: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 8: + tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + default: + ICHECK(0) << dtype; + } + } else { + ICHECK(0) << dtype; + } + return static_cast(tp); +} + +/*! + * \brief Utility function to reverse an array. + * This is commonly used to convert between row-major and column-major layouts. + */ +template static Array ReverseArray(Array array) { + return Array{array.rbegin(), array.rend()}; +} + +/*! + * \brief Construct a Copy operator node from call arguments and a buffer map. + * + * This constructor parses the first two entries of `args` as regions + * (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores + * them on the newly created CopyNode. It also + * reads optional arguments: + * - args[2] (IntImm): coalesced width (stored only if > 0), + * - args[3] (Bool): disable TMA lowering flag, + * - args[4] (IntImm): eviction policy. + * + * Preconditions: + * - `args` must contain at least two region-compatible PrimExpr entries + * (BufferLoad/BufferRegion); ICHECK will fail otherwise. + * + * @param args Array of PrimExpr where: + * - args[0] is the source Region call, + * - args[1] is the destination Region call, + * - optional args[2..4] are coalesced width, disable_tma, and eviction + * policy. 
+ */ +Copy::Copy(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + if (args.size() >= 3) { + auto coalesced_width = Downcast(args[2]); + if (coalesced_width->value > 0) { + node->coalesced_width = coalesced_width; + } + } + if (args.size() >= 4) { + node->disable_tma = Downcast(args[3]); + } + if (args.size() >= 5) { + node->eviction_policy = args[4].as()->value; + } + data_ = std::move(node); +} + +/** + * @brief Create a shallow clone of this CopyNode as a TileOperator. + * + * Produces a new CopyNode object copy-constructed from this node. If a parallel + * sub-operation (par_op_) is present, the sub-operation is cloned as well and + * attached to the new node. The returned value is a TileOperator wrapper + * around the newly created node. + * + * @return TileOperator A TileOperator owning the cloned CopyNode. + */ +TileOperator CopyNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } + return Copy(op); +} + +/*! + * \brief Create iterator variables for the copy operation. + * This function creates iteration variables for dimensions that have extent + * > 1. \return Array of IterVar representing the iterator variables for the + * copy operation. + */ +Array CopyNode::MakeIterVars() const { + // Choose the range set from the lowest-level memory scope between src and + // dst. Scope levels: global < shared/shared.dyn/shared.tmem < local.fragment + // (fragment) + auto scope_level = [](const Buffer &b) -> int { + String s = b.scope(); + if (s == "local.fragment" || s == "local") + return 2; + if (s == "shared" || s == "shared.dyn" || s == "shared.tmem") + return 1; + // default to global level for unknown scopes + return 0; + }; + + int src_level = scope_level(src); + int dst_level = scope_level(dst); + bool base_is_src = (src_level >= dst_level); + const Array &base_ranges = base_is_src ? src_range : dst_range; + + // Sanity check: when switching away from the original (src_range), + // ensure the chosen base ranges are not provably smaller than the original + // per dimension. This guards against generating undersized loop domains. + // Improved logic: use two pointers to traverse both base_ranges and + // src_range, skipping dimensions with extent == 1. The number of non-1 + // extents must match. 
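+  // Editorial example (not from the original source): with
+  //   base_ranges extents = [1, 64, 1, 32] and src_range extents = [64, 32],
+  // the two-pointer walk below skips the extent-1 dims, matches the pairs
+  // (64, 64) and (32, 32), and only aborts if a base extent is provably
+  // smaller than the src extent it is matched against; any leftover dims on
+  // either side must all have extent 1.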
+ arith::Analyzer analyzer; + + size_t base_dim = 0, src_dim = 0; + while (base_dim < base_ranges.size() && src_dim < src_range.size()) { + // Skip base extents that are 1 + while (base_dim < base_ranges.size() && + is_one(base_ranges[base_dim]->extent)) { + ++base_dim; + } + // Skip src extents that are 1 + while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) { + ++src_dim; + } + // Both indices now at non-1, or at end + if (base_dim < base_ranges.size() && src_dim < src_range.size()) { + PrimExpr base_ext = base_ranges[base_dim]->extent; + PrimExpr src_ext = src_range[src_dim]->extent; + // Only fail if base extent is provably smaller than src extent + if (analyzer.CanProve(base_ext < src_ext)) { + std::ostringstream oss; + oss << "Selected loop range is smaller than original src range at " + "matched non-1 dimension: " + << "base(extent=" << base_ext + << ", scope=" << (base_is_src ? src.scope() : dst.scope()) + << ", min=" << base_ranges[base_dim]->min + << ", base_dim=" << base_dim << ") < src(extent=" << src_ext + << ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim + << ", scope=" << src.scope() << ") for src=" << src->name + << ", dst=" << dst->name << "\n"; + oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n"; + oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n"; + oss << "base_ranges[" << base_dim + << "]: min=" << base_ranges[base_dim]->min + << ", extent=" << base_ext << "\n"; + oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min + << ", extent=" << src_ext << "\n"; + LOG(FATAL) << oss.str(); + } + ++base_dim; + ++src_dim; + } + } + + // Any remaining unmatched dimensions in either range must all have extent == + // 1 + while (base_dim < base_ranges.size()) { + ICHECK(is_one(base_ranges[base_dim]->extent)) + << "base_ranges has extra non-1 extent at dim " << base_dim; + ++base_dim; + } + while (src_dim < src_range.size()) { + ICHECK(is_one(src_range[src_dim]->extent)) + << "src_range has extra non-1 extent at dim " << src_dim; + ++src_dim; + } + + Array loop_vars; + size_t idx = 0; + for (size_t i = 0; i < base_ranges.size(); i++) { + if (is_one(base_ranges[i]->extent)) + continue; + Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype); + idx++; + loop_vars.push_back( + {Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar}); + } + return loop_vars; +} + +/*! + * \brief Create s for the copy operation. + * This function generates the actual index expressions for accessing source or + * destination buffers. For dimensions with extent=1, it uses the range minimum; + * for others, it adds the iteration variable. \param ivs Array of IterVar + * returned by MakeIterVars(). \param src_dst 0 for src_indices, 1 for + * dst_indices. \return Array of PrimExpr representing the indices for the copy + * operation. + */ +Array CopyNode::MakeIndices(const Array &ivs, + int src_dst) const { + Array indices; + Array ranges = src_dst == 0 ? src_range : dst_range; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + indices.push_back(ranges[i]->min); + else { + indices.push_back(ranges[i]->min + ivs[idx]->var); + idx++; + } + } + ICHECK(idx == ivs.size()) + << "idx = " << idx << ", ivs.size() = " << ivs.size() + << "src name = " << src->name << ", dst name = " << dst->name; + return indices; +} + +/** + * @brief Build a boundary predicate that guards memory accesses for the copy. 
+ * + * Constructs a conjunction of per-dimension bounds checks (e.g. `min + iv < + * extent` and `min + iv >= 0`) for every dynamic dimension involved in the + * copy. Uses the provided arithmetic analyzer to elide checks that can be + * proven statically. + * + * The function ICHECKs that the supplied `extents` align with the operator's + * recorded ranges for the selected side (source when `src_dst == 0`, + * destination when `src_dst == 1`). + * + * @param ivs IterVars corresponding to the varying dimensions of the copy. Each + * IterVar maps to a non-unit extent dimension in the stored ranges. + * @param extents Extents of the tensor being accessed (must match the number of + * ranges); used as the upper bounds for generated checks. + * @param src_dst Selects which side's ranges to use: `0` for source, `1` for + * destination. + * @return PrimExpr A conjunction of necessary bounds checks, or an empty + * `PrimExpr` (null) if all checks are provably true and no predicate is + * required. + */ +PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, int src_dst) const { + Array ranges = src_dst == 0 ? src_range : dst_range; + + Array cond_list; + ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + continue; + PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i]; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + cond = ranges[i]->min + ivs[idx]->var >= 0; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + idx++; + } + if (cond_list.empty()) + return {}; + else { + PrimExpr cond = cond_list[0]; + for (size_t i = 1; i < cond_list.size(); i++) + cond = And(cond, cond_list[i]); + return cond; + } +} + +/** + * @brief Construct a SIMT-style nested loop that implements the copy. + * + * Builds a loop nest that performs element-wise loads from the source buffer + * and stores into the destination buffer. For a scalar copy (no varying + * iteration dimensions) this returns a single serial loop executing one + * store. For multi-dimensional copies it: + * - creates data-parallel loops (Parallel For) for each varying dimension, + * - binds the resulting iteration variables to the provided arithmetic + * analyzer for simplification, + * - computes source and destination index expressions, + * - applies per-buffer boundary predicates (if needed) to mask out-of-range + * accesses, + * - inserts a cast when src and dst dtypes differ, + * - applies an optional `coalesced_width` annotation to generated parallel + * loops when present. + * + * @param analyzer Analyzer used to simplify and bind loop variable domains. + * @return For A nested For statement representing the generated SIMT loop nest. 
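+ *
+ * Editorial sketch (not from the original source): for a 2-D copy the
+ * generated nest is conceptually
+ *
+ *   for (i: Parallel, 0, extent_0)
+ *     for (j: Parallel, 0, extent_1)
+ *       if (dst_in_bounds(i, j))
+ *         dst[dst_min_0 + i, dst_min_1 + j] =
+ *             if_then_else(src_in_bounds(i, j),
+ *                          cast(dst_dtype, src[src_min_0 + i, src_min_1 + j]),
+ *                          0)
+ *
+ * with the predicates and the cast emitted only when actually needed.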
+ */ +For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { + Array loop_vars = MakeIterVars(); + bool is_scalar = loop_vars.empty(); + + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + ICHECK(loop_vars.size() <= src_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", src_range.size() = " << src_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + ICHECK(loop_vars.size() <= dst_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + Array src_indices = MakeIndices(loop_vars, 0); + Array dst_indices = MakeIndices(loop_vars, 1); + + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); + + PrimExpr value = BufferLoad(src, src_indices); + if (src->dtype != dst->dtype) + value = Cast(dst->dtype, value); + if (src_predicate.defined()) + value = if_then_else(src_predicate, value, make_zero(dst->dtype)); + + Stmt body = BufferStore(dst, value, dst_indices); + if (dst_predicate.defined()) + body = IfThenElse(dst_predicate, body); + if (is_scalar) { + return For(Var("i"), 0, 1, ForKind::kSerial, body); + } + for (int i = loop_vars.size() - 1; i >= 0; i--) { + Map annotations = {}; + if (coalesced_width.defined()) { + annotations.Set("coalesced_width", coalesced_width); + } + body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, + ForKind::kParallel, body, std::nullopt, annotations); + } + return Downcast(body); +} + +/** + * @brief Compute a linearized shared-memory layout used for TMA transfers. + * + * Creates a Layout that maps an N-D shared tensor into a 1-D-like ordering + * suitable for TMA by blocking each dimension into 256-element tiles and + * splitting each original index into a quotient and remainder. Effectively + * transforms each index i_k into two coordinates: floor(i_k / 256) and + * i_k % 256, producing an ordering equivalent to concatenating all quotients + * followed by all remainders. + * + * @param shared_tensor The shared-memory buffer whose shape defines the input + * dimensions for the layout inference. + * @return Layout A Layout describing the linearized ordering for the TMA copy. + */ +Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const { + Array input_size = shared_tensor->shape; + Array forward_vars; + for (size_t i = 0; i < input_size.size(); i++) { + forward_vars.push_back(InputPlaceholder(i)); + } + // [i, j] -> [i // 256, j // 256, i % 256, j % 256] + Array forward_index; + for (size_t i = 0; i < input_size.size(); i++) { + forward_index.push_back(FloorDiv(forward_vars[i], 256)); + } + for (size_t i = 0; i < input_size.size(); i++) { + forward_index.push_back(FloorMod(forward_vars[i], 256)); + } + return Layout(input_size, forward_index); +} + +/** + * @brief Infer memory layouts for this Copy operation. + * + * Determines an appropriate LayoutMap for the copy based on the target and + * enabled lowering paths. For TMA-capable targets when the chosen copy + * instruction is BulkLoad or BulkStore, this may produce a linearized shared + * memory layout suitable for TMA transfers (only when inference is invoked at + * InferLevel::kFree and no layout for the shared buffer is already annotated). + * For other cases (including LDSM/STSM and the normal copy path), layout + * inference is delegated to the SIMT parallel operation produced by + * MakeSIMTLoop(). 
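+ *
+ * Editorial note (hedged, not from the original source): callers typically
+ * drive inference in several rounds of increasing freedom, mirroring the
+ * pattern used by LowerNormalCopy() later in this file:
+ *
+ *   for (auto level : {InferLevel::kCommon, InferLevel::kStrict,
+ *                      InferLevel::kFree})
+ *     layout_map = op->InferLayout(infer_args, level);
+ *
+ * where op and infer_args stand in for the concrete operator and
+ * LayoutInferArgs at hand.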
+ * + * This method may read PassContext configuration (kDisableTMALower) and may + * lazily construct and cache the parallel operation in par_op_ as a side + * effect. + * + * @param T LayoutInferArgs containing target and the current layout map. + * @param level The inference level controlling how aggressive/layouts may be + * proposed. + * @return LayoutMap mapping buffers to inferred layouts (may be empty if no + * additional layouts are suggested). + */ +LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + auto target = T.target; + using namespace tvm::transform; + PassContext pass_ctx = PassContext::Current(); + bool disable_tma_lower = + pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); + auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, + T.layout_map, T.analyzer, T.buffer_oob); + + // Handle tensor memory (tmem) layout inference + if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { + // Tensor memory copy + // TODO (mzw) Add support for tcgen05.st/cp (in conj. with LowerTmemCopy) + ICHECK(copy_inst == CopyInst::kTMemLoad) + << "Only support tensor memory copy from shared.tmem to local.fragment " + "currently"; + LayoutMap results; + if (!T.layout_map.count(dst) && T.layout_map.count(src)) { + // Use the default layout (32dp32b) if not specified + // NOTE (mzw) We will check the layout in LowerTmemCopy(), so don't + // worry for tmem-incompatible layout + Layout src_layout = T.layout_map[src]; + Array logical_coords = MakeIterVars(); + Array logical_coords_var = {logical_coords[0]->var, + logical_coords[1]->var}; + Array phy_indices = src_layout->Forward(logical_coords_var); + + // Tmem physical coord range analysis + auto analyzer = std::make_shared(); + for (const auto &iv : logical_coords) + analyzer->Bind(iv->var, iv->dom); + arith::ConstIntBound phy_row_bounds = + analyzer->const_int_bound(phy_indices[0]); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(phy_indices[1]); + Range row_dom = Range((int)(phy_row_bounds->min_value), + (int)(phy_row_bounds->max_value + 1)); + Range col_dom = Range((int)(phy_col_bounds->min_value), + (int)(phy_col_bounds->max_value + 1)); + + constexpr int WARP_SIZE = 32; // Set to 32 since only sm100 is supported + constexpr int WARPGROUP_SIZE = 4 * WARP_SIZE; + ICHECK(is_const_int(T.thread_bounds->extent)) + << "Tensor memory copy requires thread_bounds->extent (num_threads) " + "to be constant integers"; + int num_threads = *as_const_int(T.thread_bounds->extent); + ICHECK(num_threads % WARPGROUP_SIZE == 0) + << "Tensor memory copy requires thread bounds to be aligned to " + "warpgroups, but found " + << "thread range = " << T.thread_bounds; + + for (int num_useful_wgs = num_threads / WARPGROUP_SIZE; + num_useful_wgs >= 1; --num_useful_wgs) { + int num_useful_threads = num_useful_wgs * WARPGROUP_SIZE; + Tcgen05Meta meta = getTcgen05Meta_32dp32b(); + auto [is_success, tmem_coord2frag, num_chunks_each_wg] = + expandTcgen05Layout( + meta, phy_col_bounds->max_value - phy_col_bounds->min_value + 1, + num_useful_threads, row_dom, col_dom); + if (!is_success) { + continue; + } + Fragment logical_coord2frag = + Fragment(logical_coords, tmem_coord2frag->Forward(phy_indices), + tmem_coord2frag->ForwardThread(phy_indices, std::nullopt), + make_itervar("rep", 1)); + results.Set(dst, logical_coord2frag->BindThreadRange(T.thread_bounds)); + break; + } + } + return results; + } + + if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) 
{ + // if can apply swizzling, we skip layout inference + // for bulk load/store, we can directly apply the layout of normal copy + // This must be a global/shared layout, so we can skip the parallel op + // layout inference (parallel layout inference only annotate the loop layout + // and the register layout). + bool is_load = + copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D; + Buffer global_tensor = is_load ? src : dst; + Buffer shared_tensor = is_load ? dst : src; + + Map result_map; + + // Collect fragment buffers from indices and mark them as fully replicated + // For Bulk Load/Store, fragment buffers used as indices should be + // replicated across all threads + PrimExpr thread_extent = T.thread_bounds->extent; + for (const auto &range : src_range) { + CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map, + thread_extent, T.thread_bounds, result_map); + CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map, + thread_extent, T.thread_bounds, result_map); + } + for (const auto &range : dst_range) { + CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map, + thread_extent, T.thread_bounds, result_map); + CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map, + thread_extent, T.thread_bounds, result_map); + } + + // check shared layout is non-swizzle + // skip layout inference if shared layout is already annotated + if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) { + // create a new layout map for tma linear layout + Layout linear_layout = ComputeLinearLayout(shared_tensor); + result_map.Set(shared_tensor, linear_layout); + } + return result_map; + } + // for LDSM/STSM, the layout was deduced from register layout + // so we can directly apply the layout of normal copy + // Use parallel op to infer the layout + if (!par_op_.defined()) { + arith::Analyzer analyzer; + par_op_ = ParallelOp((MakeSIMTLoop(&analyzer))); + } + auto layout_map = par_op_->InferLayout(T, level); + return layout_map; +} +/** + * @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA) + * instruction. + * + * The function returns true when all of the following hold: + * - the target architecture advertises bulk-copy/TMA support; + * - the source buffer resides in global memory; + * - the destination buffer resides in shared memory (either "shared" or + * "shared.dyn"); + * - the source and destination have the same element data type. + * + * If the source and destination dtypes differ, a warning is logged and the + * function returns false (the caller is expected to fall back to a normal + * copy). + * + * @param target The compilation target to query for bulk-copy support. + * @return true if the copy can be implemented as a Bulk Load (TMA); false + * otherwise. + */ +bool CopyNode::CheckBulkLoad(Target target, arith::Analyzer *analyzer, + bool check_last_dim) const { + // 1. arch must have bulk copy support + if (!TargetHasBulkCopy(target)) + return false; + // 2. src and dst must be global and shared + if (src.scope() != "global" || + (dst.scope() != "shared.dyn" && dst.scope() != "shared")) + return false; + // 3. check shape. 
+ // last dim of src * dtype.bits() must be a multiple of 16 + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // now we check src (gmem) as tma box dim is deduced from src + if (check_last_dim && + analyzer->CanProve( + FloorMod(src_range[src_range.size() - 1]->extent * src->dtype.bytes(), + 16) != 0, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) + << "src range must have last dim multiple of 16 for tma bulk load " + << src->name << " range " << src_range[src_range.size() - 1]->extent + << " * " << src->dtype.bytes() << " % 16 != 0"; + return false; + } + + // 4. src and dst must have the same dtype + if (src->dtype != dst->dtype) { + LOG(WARNING) << "src and dst must have the same dtype for tma load " + << src->name << " vs. " << dst->name << " dtype " << src->dtype + << " vs. " << dst->dtype << " will be fallback to normal copy"; + return false; + } + return true; +} + +bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor, + const Buffer &shared_tensor, + const Array &global_range, + const Array &shared_range, + const LayoutMap &layout_map, + arith::Analyzer *analyzer) const { + + // Step 1: check shared is contiguous + bool shared_is_contiguous = true; + if (layout_map.count(shared_tensor)) { + shared_is_contiguous = false; + } + // Step 2: check global is contiguous + bool global_is_contiguous = true; + bool global_not_full_dim_encounter = false; + for (int i = global_range.size() - 1; i >= 0; i--) { + if (!global_not_full_dim_encounter) { + if (!analyzer->CanProve(global_range[i]->extent == + global_tensor->shape[i] && + global_range[i]->min == 0, + arith::ProofStrength::kSymbolicBound)) { + global_not_full_dim_encounter = true; + } + } else { + if (!analyzer->CanProve(global_range[i]->extent == 1, + arith::ProofStrength::kSymbolicBound)) { + global_is_contiguous = false; + break; + } + } + } + + // Step 3: check element match and no OOB + PrimExpr shared_elements = 1; + for (size_t i = 0; i < shared_range.size(); i++) { + shared_elements *= shared_range[i]->extent; + } + PrimExpr global_elements = 1; + for (size_t i = 0; i < global_range.size(); i++) { + global_elements *= global_range[i]->extent; + } + bool element_match = + analyzer->CanProveEqual(shared_elements, global_elements); + + return (shared_is_contiguous && global_is_contiguous && element_match); +} + +bool CopyNode::CheckBulkLoad1D(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer) const { + if (!CheckBulkLoad(target, analyzer, false)) + return false; + auto global_tensor = src; + auto shared_tensor = dst; + auto global_range = src_range; + auto shared_range = dst_range; + return CheckBulkCopy1D(global_tensor, shared_tensor, global_range, + shared_range, layout_map, analyzer); +} + +bool CopyNode::CheckBulkStore1D(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer) const { + if (!CheckBulkStore(target, analyzer, false)) + return false; + auto shared_tensor = src; + auto global_tensor = dst; + auto shared_range = src_range; + auto global_range = dst_range; + return CheckBulkCopy1D(global_tensor, shared_tensor, global_range, + shared_range, layout_map, analyzer); +} + +/** + * @brief Determine if this CopyNode can be lowered to a CUDA BulkStore (TMA + * store). 
+ * + * Checks whether the target supports bulk copy, the source buffer is in shared + * memory (shared or shared.dyn), the destination buffer is in global memory, + * and both buffers have the same element data type. If the data types differ, + * a warning is logged and false is returned. + * + * @param target Target device/architecture to check for bulk-copy support. + * @return true if all conditions for a BulkStore are met; false otherwise. + */ +bool CopyNode::CheckBulkStore(Target target, arith::Analyzer *analyzer, + bool check_last_dim) const { + // 1. arch must have bulk copy support + if (!TargetHasBulkCopy(target)) + return false; + // 2. src and dst must be shared.dyn and local.fragment + if ((src.scope() != "shared.dyn" && src.scope() != "shared") || + dst.scope() != "global") + return false; + // 3. check shape. + // last dim of dst * dtype.bits() must be a multiple of 16 + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // now we check dst (gmem) as tma box dim is deduced from dst + if (check_last_dim && + analyzer->CanProve( + FloorMod(dst_range[dst_range.size() - 1]->extent * dst->dtype.bytes(), + 16) != 0, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) + << "dst range must have last dim multiple of 16 for tma bulk store " + << dst->name << " range " << dst_range[dst_range.size() - 1]->extent + << " * " << dst->dtype.bytes() << " % 16 != 0"; + return false; + } + // 4. src and dst must have the same dtype + if (src->dtype != dst->dtype) { + LOG(WARNING) << "src and dst must have the same dtype for tma store " + << src->name << " vs. " << dst->name << " dtype " << src->dtype + << " vs. " << dst->dtype << " will be fallback to normal copy"; + return false; + } + return true; +} + +/*! + * \brief Check if the copy operation is a LDSM copy. + * This function verifies if the copy operation can be implemented using CUDA's + * Load Matrix (LDSM) instruction. Requirements include: target supports + * LDMATRIX, source is shared.dyn, destination is local.fragment. \param target + * Target device. \return True if the copy operation is a LDSM copy, false + * otherwise. + */ +bool CopyNode::CheckLDSMCopy(Target target) const { + return TargetHasLdmatrix(target) && + (src.scope() == "shared.dyn" || src.scope() == "shared") && + dst.scope() == "local.fragment"; +} + +/** + * @brief Determine whether this copy can use the STMATRIX store (STSM) path. + * + * Returns true when the target supports STMATRIX and the source buffer is in + * the `local.fragment` scope while the destination buffer is in shared memory + * (`shared` or `shared.dyn`). + * + * @param target The compilation target to query for STMATRIX support. + * @return true if the copy may be lowered to an STSM instruction; false + * otherwise. + */ +bool CopyNode::CheckSTSMCopy(Target target) const { + return TargetHasStmatrix(target) && src.scope() == "local.fragment" && + (dst.scope() == "shared.dyn" || dst.scope() == "shared"); +} + +/** + * @brief Determine whether this copy can use tensor memory load (tcgen05.ld). + * + * Returns true when the target supports tensor memory and the source buffer is + * in `shared.tmem` scope while the destination buffer is in `local.fragment`. + * + * @param target The compilation target to query for tensor memory support. + * @return true if the copy may be lowered to a tcgen05.ld instruction; false + * otherwise. 
+ */ +bool CopyNode::CheckTMemLoad(Target target) const { + return TargetHasTmem(target) && src.scope() == "shared.tmem" && + dst.scope() == "local.fragment"; +} + +/** + * @brief Determine whether this copy can use tensor memory store (tcgen05.st). + * + * Returns true when the target supports tensor memory and the source buffer is + * in `local.fragment` scope while the destination buffer is in `shared.tmem`. + * + * @param target The compilation target to query for tensor memory support. + * @return true if the copy may be lowered to a tcgen05.st instruction; false + * otherwise. + */ +bool CopyNode::CheckTMemStore(Target target) const { + return TargetHasTmem(target) && src.scope() == "local.fragment" && + dst.scope() == "shared.tmem"; +} + +/** + * @brief Selects the most specific copy instruction supported for the given + * target and buffers. + * + * Determines which specialized copy lowering to use (TMA bulk load/store, LDSM, + * STSM, TMem load/store) based on target capabilities and the memory scopes of + * the source/destination buffers. If TMA lowering is disabled via the flag, + * BulkLoad/BulkStore are not selected. The selection priority is: TMemLoad, + * TMemStore, BulkLoad1D, BulkStore1D, BulkLoad, BulkStore, LDSM, STSM, then + * Normal (fallback). + * + * @param target The compilation target used to query hardware capabilities. + * @param disable_tma_lower If true, prevents selecting TMA-based bulk + * load/store instructions. + * @return CopyInst The chosen copy instruction enum value. + */ +CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, + const LayoutMap &layout_map, + arith::Analyzer *analyzer, + bool buffer_oob = false) const { + // disable_tma_lower is from pass_configs + // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True, + // we will not use tma for bulk load/store + + // Check tensor memory operations first (highest priority for SM100/Blackwell) + // 1d tma access can not support out of bound access + if (!disable_tma_lower && !buffer_oob && + CheckBulkLoad1D(target, layout_map, analyzer)) { + return CopyInst::kBulkLoad1D; + } else if (!disable_tma_lower && !buffer_oob && + CheckBulkStore1D(target, layout_map, analyzer)) { + return CopyInst::kBulkStore1D; + } else if (!disable_tma_lower && CheckBulkLoad(target, analyzer)) { + return CopyInst::kBulkLoad; + } else if (!disable_tma_lower && CheckBulkStore(target, analyzer)) { + return CopyInst::kBulkStore; + } else if (CheckLDSMCopy(target)) { + return CopyInst::kLDSM; + } else if (CheckSTSMCopy(target)) { + return CopyInst::kSTSM; + } else if (CheckTMemLoad(target)) { + return CopyInst::kTMemLoad; + } else if (CheckTMemStore(target)) { + return CopyInst::kTMemStore; + } else { + return CopyInst::kNormal; + } +} + +/*! + * \brief Lower the copy operation to PTX code. + * This function converts the high-level copy operation into low-level PTX + * instructions. It dispatches to specialized lowering functions based on the + * determined copy instruction type: + * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions + * - LDSM/STSM: Uses matrix load/store instructions for tensor cores + * - Normal: Uses standard load/store operations with loop transformations + * \param T LowerArgs containing target and layout map. + * \param analyzer Arithmetic analyzer for simplification. + * \return Stmt representing the PTX code for the copy operation. 
+ */ +Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + + using namespace tvm::transform; + PassContext pass_ctx = PassContext::Current(); + bool disable_tma_lower = + pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); + auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, + T.layout_map, analyzer); + if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { + auto tmem_copy = LowerTmemCopy(T, analyzer); + ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; + return tmem_copy; + } else if (copy_inst == CopyInst::kBulkLoad1D || + copy_inst == CopyInst::kBulkStore1D) { + auto bulk_copy = LowerBulkCopy1D(T, analyzer, copy_inst); + ICHECK(bulk_copy.defined()) << "Failed to lower bulk load 1d"; + return bulk_copy; + } else if (copy_inst == CopyInst::kBulkLoad || + copy_inst == CopyInst::kBulkStore) { + auto bulk_copy = LowerBulkCopy(T, analyzer, copy_inst); + ICHECK(bulk_copy.defined()) << "Failed to lower bulk load/store"; + return bulk_copy; + } else if (copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) { + auto ldsm_copy = LowerLDSMCopy(T, analyzer, copy_inst); + ICHECK(ldsm_copy.defined()) << "Failed to lower ptx matrix copy"; + return ldsm_copy; + } else if (copy_inst == CopyInst::kNormal) { + return LowerNormalCopy(T, analyzer); + } else { + LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); + } +} + +/** + * @brief Lower the copy operator using the generic (non-specialized) path. + * + * Generates standard load/store code paths for targets that cannot or should + * not use specialized copy instructions (TMA, LDSM/STSM). Builds a SIMT loop, + * fuses and transforms parallel loops, infers and applies loop layouts on GPU + * targets, partitions by thread, and applies vectorization appropriate to the + * device (CPU or GPU). If a thread-level predicate is required, the resulting + * body is guarded with an IfThenElse. + * + * @param T Lowering context including the target, thread bounds, thread var, + * layout map, and buffer remapping used during layout inference and + * loop partitioning. + * @param analyzer Arithmetic analyzer used to simplify and reason about bounds + * during loop partitioning and predicate construction. + * @return Stmt Lowered statement representing the transformed, vectorized + * normal-copy loop (possibly wrapped in a predicate). 
+ */
+Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
+                               arith::Analyzer *analyzer) const {
+  bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
+  auto simt_loop = MakeSIMTLoop(analyzer);
+  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
+
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+
+  For vectorized_thread_loop;
+  auto par_op = ParallelOp(transformed_loop);
+
+  if (is_cpu_target || dst.scope() == "local" || src.scope() == "local") {
+    if (src.scope() == "local" && dst.scope() != "local") {
+      LOG(WARNING) << "Copy from local buffer `" << src->name << "` to "
+                   << dst.scope() << " buffer `" << dst->name
+                   << "` may cause conflicting writes.";
+    }
+    vectorized_thread_loop = VectorizeLoop(transformed_loop);
+  } else {
+    std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
+                                      InferLevel::kFree};
+    for (auto level : levels) {
+      par_op->InferLayout({T.target,
+                           T.thread_bounds,
+                           T.layout_map,
+                           analyzer,
+                           false,
+                           T.buffer_remap,
+                           {}},
+                          level);
+    }
+    auto loop_layout = par_op->GetLoopLayout();
+    auto thread_var = T.thread_var;
+    auto thread_loop =
+        PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
+    vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
+  }
+
+  if (par_op->GetPredicate(T.thread_var).defined()) {
+    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
+                      vectorized_thread_loop);
+  }
+  return vectorized_thread_loop;
+}
+
+/**
+ * @brief Lower a Copy operator to LDSM/STSM (warp-level 8x8 matrix)
+ * instructions.
+ *
+ * Lowers a CopyNode into PTX matrix load/store (LDSM/STSM) sequences when the
+ * access/layouts meet the hardware constraints required by warp-level 8x8
+ * fragment transfers (thread-mapped 8x8 fragment layout, 16-byte contiguous
+ * shared memory accesses, full-range local tiles, matching dtypes for loads,
+ * and no access predicates). If these conditions are not met the function
+ * falls back to lowering via LowerNormalCopy().
+ *
+ * The routine validates layout/thread-mapping compatibility (including support
+ * for transposed fragment layouts), determines the vectorization factor
+ * (4/2/1) based on extent alignment, computes shared/local addresses, emits
+ * the appropriate ptx_ldmatrix/ptx_stmatrix call(s), and wraps them in a small
+ * loop that may be unrolled and adjusted for thread-bounds offsets.
+ *
+ * @param T Lowering context (target, layout/buffer remaps, thread bounds).
+ * @param analyzer Arithmetic analyzer used to simplify and prove bounds.
+ * @param copy_inst Must be either CopyInst::kLDSM or CopyInst::kSTSM to select
+ * matrix-load vs matrix-store lowering.
+ * @return Stmt A statement implementing the LDSM/STSM lowering, or the result
+ * of LowerNormalCopy(...) when constraints require fallback.
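+ *
+ * @code
+ * // Expected 8x8 fragment ownership for the non-transposed case (a sketch
+ * // of what makeGemmFragment8x8() encodes): each of the 32 lanes holds two
+ * // consecutive 16-bit elements of one row, i.e.
+ * //   lane(row, col) = (row % 8) * 4 + (col % 8) / 2
+ * // which is the fragment produced/consumed by one ldmatrix/stmatrix .m8n8
+ * // operation.
+ * @endcode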
+ */ +Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { + ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) + << "Invalid copy inst " << static_cast(copy_inst); + bool is_ldmatrix = copy_inst == CopyInst::kLDSM; + + // Check no predicates + Array loop_vars = MakeIterVars(); + if (loop_vars.size() < 2) { + // cannot support 1-d case + return LowerNormalCopy(T, analyzer); + } + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); + if (src_predicate.defined() || dst_predicate.defined()) { + // stmatrix and ldmatrix can only support no predicate + return LowerNormalCopy(T, analyzer); + } + + Buffer shared_tensor = is_ldmatrix ? src : dst; + Buffer local_tensor = is_ldmatrix ? dst : src; + + Array local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0); + Fragment local_layout = Downcast(T.layout_map[local_tensor]); + Array local_indices_transformed = + local_layout->Forward(local_indices); + local_tensor = T.buffer_remap[local_tensor]; + // currently only support 1-d case + if (local_layout->OutputDim() != 1) { + // TMA ldmatrix/stmatrix cannot support non-1-d layout, will be fallback to + // normal copy + return LowerNormalCopy(T, analyzer); + } + + Array shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1); + Array shared_indices_transformed = shared_indices; + Layout shared_layout; + if (T.buffer_remap.count(shared_tensor)) { + shared_layout = T.layout_map[shared_tensor]; + shared_tensor = T.buffer_remap[shared_tensor]; + shared_indices_transformed = shared_layout->Forward(shared_indices); + } + + // Check local_layout follows 8x8 layout + // LDSM/STSM instructions require 8x8 matrix fragment layout + // This matches the warp-level matrix multiplication pattern used in tensor + // cores We check both normal and transposed layouts to support different + // access patterns + bool is_transposed; + IterVar col_var = loop_vars[loop_vars.size() - 1]; + IterVar row_var = loop_vars[loop_vars.size() - 2]; + PrimExpr local_layout_thread_map = + FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32); + PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread( + {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt); + PrimExpr matrix_8x8_thread_map_trans = + makeGemmFragment8x8Transposed()->ForwardThread( + {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt); + PrimExpr local_indices_flattened = + local_tensor.OffsetOf(local_indices_transformed).back(); + if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) && + IndiceCanVectorize(local_indices_flattened, col_var->var, + col_var->dom->extent, 2, analyzer)) { + is_transposed = false; + } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans, + local_layout_thread_map) && + IndiceCanVectorize(local_indices_flattened, row_var->var, + row_var->dom->extent, 2, analyzer)) { + is_transposed = true; + } else { + // TMA ldmatrix/stmatrix cannot support non-8x8 layout, will be fallback to + // normal copy + return LowerNormalCopy(T, analyzer); + } + // Check shared_layout is 16 bytes continuous + // LDSM/STSM instructions require 16-byte aligned data (half-precision floats) + // This is a hardware constraint for matrix load/store operations + if (shared_tensor->dtype.bytes() != 2) { + // TMA ldmatrix/stmatrix cannot support non-16 
bytes continuous layout, will + // be fallback to normal copy + return LowerNormalCopy(T, analyzer); + } + PrimExpr flattened_indice = + shared_tensor.OffsetOf(shared_indices_transformed).back(); + if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var, + loop_vars.back()->dom->extent, 8, analyzer)) { + // TMA ldmatrix/stmatrix cannot support non-16 bytes continuous layout, will + // be fallback to normal copy + return LowerNormalCopy(T, analyzer); + } + + // Can only support local_range to be a full range + for (size_t i = 0; i < dst_range.size(); i++) { + if (!is_zero(dst_range[i]->min) || + !analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i])) + // TMA ldmatrix/stmatrix cannot support non-full range, will be fallback + // to normal copy + return LowerNormalCopy(T, analyzer); + } + + // Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1 + PrimExpr extent = local_tensor->shape[0]; + int num = 1; + if (analyzer->CanProveEqual(FloorMod(extent, 8), 0)) + num = 4; + else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0)) + num = 2; + + Array args; + const Op &op = is_ldmatrix ? tl::ptx_ldmatrix() : tl::ptx_stmatrix(); + args.push_back(static_cast(is_transposed)); + args.push_back(num); + + // Create shared address with regard to local address + // if not transpose + // coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4)) + // if transpose + // coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread + // % 8 / 2) + Var local_iter("i"); + Layout inv = local_layout->Inverse(); + Array shared_coords; + PrimExpr warp = FloorDiv(T.thread_var, 32) * 32; + if (!is_transposed) + shared_coords = inv->Forward( + {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num), + warp + FloorMod(T.thread_var, 8) * 4}); + else + shared_coords = inv->Forward( + {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) + + FloorMod(T.thread_var, 2), + warp + FloorDiv(FloorMod(T.thread_var, 8), 2)}); + shared_coords.pop_back(); // remove rep + if (shared_layout.defined()) + shared_coords = shared_layout->Forward(shared_coords); + PrimExpr shared_addr = shared_tensor.access_ptr( + is_ldmatrix ? 
1 : 2, DataType::Handle(), 1, + shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num)); + args.push_back(shared_addr); + + if (is_ldmatrix) { + // Can only support same dtype for ldmatrx + if (local_tensor->dtype != shared_tensor->dtype) { + // TMA ldmatrix cannot support different dtype, will be fallback to normal + // copy + return LowerNormalCopy(T, analyzer); + } + PrimExpr local_addr = local_tensor.access_ptr( + 2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num)); + args.push_back(local_addr); + } else { + for (int i = 0; i < num; i++) { + PrimExpr value0 = + BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i}); + PrimExpr value1 = + BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1}); + if (local_tensor->dtype != shared_tensor->dtype) { + value0 = Cast(shared_tensor->dtype, value0); + value1 = Cast(shared_tensor->dtype, value1); + } + PrimExpr value_packed = + Call(DataType::Int(32), pack_b16(), {value0, value1}); + args.push_back(value_packed); + } + } + + auto body = Evaluate(Call(DataType::Handle(), op, args)); + For for_node = + For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body); + for_node = LoopPragmaUnroll(for_node); + auto range = T.thread_bounds; + if (range.defined()) { + auto thread_var = T.thread_var; + auto thread_var_with_offset = thread_var - range->min; + for_node.CopyOnWrite()->body = + Substitute(for_node->body, {{thread_var, thread_var_with_offset}}); + } + return for_node; +} + +/** + * @brief Lower tensor memory copy operations (tcgen05.ld/st/cp). + * + * Handles copy operations involving shared.tmem buffers (tensor memory on + * SM100/Blackwell). Supports three types of tensor memory copies: + * - tcgen05.ld: tensor memory -> register (local.fragment) + * - tcgen05.st: register (local.fragment) -> tensor memory + * - tcgen05.cp: shared memory -> tensor memory + * + * The function validates buffer scopes, extracts 2D loop structure, performs + * layout compatibility checks, selects an appropriate TCGEN05 instruction + * variant based on data width and thread count, and emits the corresponding PTX + * intrinsic call. + * + * Currently only tcgen05.ld is fully supported; st/cp will trigger an ICHECK + * failure. + * + * @param T Lowering context (target, thread bounds, layout maps, buffer + * remaps). + * @param analyzer Arithmetic analyzer for proving bounds and simplifying + * expressions. + * @return Stmt The lowered tensor memory copy statement, or an empty Stmt if + * this copy does not involve tensor memory. 
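+ *
+ * @code
+ * // Shape of the emitted call for the tcgen05.ld path (names illustrative;
+ * // the exact intrinsic string comes from the chosen Tcgen05Meta):
+ * //   call_extern("<meta.intrinsics_name><num_chunks, pack16b>",
+ * //               src[row_min, col_min],   // rewritten by lower_shared_tmem
+ * //               col_offset,
+ * //               dst.access_ptr(...));
+ * // guarded by `tx < thread_bounds.min + num_useful_threads` when only a
+ * // subset of the warpgroups participates.
+ * @endcode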
+ */ +Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { + if (src.scope() != "shared.tmem" && dst.scope() != "shared.tmem") { + return Stmt(); + } + ICHECK(TargetHasTmem(T.target)) << "Target " << T.target->ToDebugString() + << " does not support tensor memory copy"; + + // Decide copy type + bool is_ld = false; // tcgen05.ld (tensor memory -> register) + bool is_st = false; // tcgen05.st (register -> tensor memory) + bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) + bool src_needs_pack = + 16 == src->dtype.bits(); // if needs .pack::16b when is_ld + bool dst_needs_unpack = + 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st + + if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { + is_ld = true; + } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { + is_st = true; + } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { + is_cp = true; + } else { + ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = " + << src.scope() << ", dst scope = " << dst.scope(); + } + // Currently tcgen05.cp is not supported + // TODO (mzw) Support tcgen05.cp + ICHECK(!is_cp) + << "Copy from shared memory to tensor memory is not supported yet"; + // Currently tcgen05.st is not supported + // TODO (mzw) Support tcgen05.st + ICHECK(!is_st) << "Copy from register to tensor memory is not supported yet"; + + // Extract loop variables and ranges + Array loop_vars = MakeIterVars(); + ICHECK(loop_vars.size() == 2) << "Only support 2D tensor memory copy, got " + << loop_vars.size() << " dimensions"; + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); + ICHECK(!src_predicate.defined() && !dst_predicate.defined()) + << "Tensor memory copy does not support predicates, got " << src_predicate + << " and " << dst_predicate; + ICHECK(is_const_int(loop_vars[0]->dom->min) && + is_const_int(loop_vars[0]->dom->extent) && + is_const_int(loop_vars[1]->dom->min) && + is_const_int(loop_vars[1]->dom->extent)) + << "Tensor memory copy requires loop bounds to be constant integers"; + int64_t logical_row_min = *as_const_int(loop_vars[0]->dom->min); + int64_t logical_row_extent = *as_const_int(loop_vars[0]->dom->extent); + int64_t logical_col_min = *as_const_int(loop_vars[1]->dom->min); + int64_t logical_col_extent = *as_const_int(loop_vars[1]->dom->extent); + + // Extract thread bounds + constexpr int WARP_SIZE = 32; // Set to 32 since only sm100 is supported + constexpr int WARPGROUP_SIZE = 4 * WARP_SIZE; + ICHECK(is_const_int(T.thread_bounds->extent)) + << "Tensor memory copy requires thread_bounds->extent (num_threads) to " + "be constant integers"; + int num_threads = *as_const_int(T.thread_bounds->extent); + ICHECK(analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, WARPGROUP_SIZE), + 0) && + num_threads % WARPGROUP_SIZE == 0) + << "Tensor memory copy requires thread bounds to be aligned to " + "warpgroups, but found " + << "thread range = " << T.thread_bounds; + + // TODO (mzw) Buffer remap for shared.dyn when is_cp is true? 
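+
+  // What follows (see the helper lambda below): map the logical (row, col)
+  // indices through the tmem layout to physical tensor-memory coordinates,
+  // bound their ranges, then probe the tcgen05.ld shapes
+  // (32dp32b/64b/128b/256b) for one whose expanded fragment layout matches
+  // the destination fragment for some number of participating warpgroups.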
+ + // Retrieve layout + ICHECK(T.layout_map.count(src)) + << "Source buffer " << src->name << " does not have a layout specified"; + ICHECK(T.layout_map.count(dst)) << "Destination buffer " << dst->name + << " does not have a layout specified"; + Layout src_layout = T.layout_map[src]; + Fragment dst_layout = Downcast(T.layout_map[dst]); + + // Check layout + Array logical_indices = MakeIndices(loop_vars, 0); + Array phy_indices = + src_layout->Forward(logical_indices); // "phy" for "physical" + + // Analyse the range of tmem_phy_row and tmem_phy_col + arith::ConstIntBound phy_row_bounds = + analyzer->const_int_bound(phy_indices[0]); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(phy_indices[1]); + int tmem_phy_row_min = phy_row_bounds->min_value; + int tmem_phy_row_max = phy_row_bounds->max_value; + int tmem_phy_col_min = phy_col_bounds->min_value; + int tmem_phy_col_max = phy_col_bounds->max_value; + int tmem_phy_row_extent = tmem_phy_row_max - tmem_phy_row_min + 1; + int tmem_phy_col_extent = tmem_phy_col_max - tmem_phy_col_min + 1; + Range row_dom = Range(tmem_phy_row_min, tmem_phy_row_max + 1); + Range col_dom = Range(tmem_phy_col_min, tmem_phy_col_max + 1); + + bool have_succeeded = false; + Stmt body; + + auto try_tcgen05_instruction = [&](Tcgen05Meta meta) { + if (have_succeeded) { + return; + } + if (tmem_phy_row_min != 0 || tmem_phy_row_max != 127) { + return; + } + if (tmem_phy_col_min % meta.width != 0 || + (tmem_phy_col_max + 1) % meta.width != 0) { + return; + } + + for (int num_useful_wgs = num_threads / WARPGROUP_SIZE; num_useful_wgs >= 1; + num_useful_wgs--) { + int num_useful_threads = num_useful_wgs * WARPGROUP_SIZE; + auto [is_success, target_frag, num_chunks_each_wg] = expandTcgen05Layout( + meta, tmem_phy_col_extent, num_useful_threads, row_dom, col_dom); + if (!is_success) { + continue; + } + + PrimExpr target_thread = + target_frag->ForwardThread(phy_indices, std::nullopt); + PrimExpr dst_thread = + dst_layout->ForwardThread(logical_indices, std::nullopt); + if (!analyzer->CanProveEqual(target_thread, dst_thread)) { + continue; + } + PrimExpr target_reg = target_frag->Forward(phy_indices)[0]; + PrimExpr dst_reg = dst_layout->Forward(logical_indices)[0]; + if (!analyzer->CanProveEqual(target_reg, dst_reg)) { + continue; + } + + // All checks passed, we can use this instruction + PrimExpr relative_wg_idx = + FloorDiv(Sub(T.thread_var, T.thread_bounds->min), WARPGROUP_SIZE); + PrimExpr col_offset = + num_useful_threads == WARPGROUP_SIZE + ? PrimExpr(0) + : relative_wg_idx * (num_chunks_each_wg * meta.width); + have_succeeded = true; + Array args; + const char *bool_str = src_needs_pack ? 
"true" : "false"; + args.push_back(StringImm(meta.intrinsics_name + "<" + + std::to_string(num_chunks_each_wg) + ", " + + bool_str + ">")); + args.push_back( + BufferLoad(src, {(int)logical_row_min, + (int)logical_col_min})); // Will be translated later + // in lower_shared_tmem pass + args.push_back(col_offset); + args.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, + PrimExpr(tmem_phy_col_extent))); + + Stmt call = + Evaluate(Call(DataType::Handle(), builtin::call_extern(), args)); + if (num_useful_threads != num_threads) { + body = + IfThenElse(T.thread_var < T.thread_bounds->min + num_useful_threads, + call, // No-op for unused threads + Stmt()); + } else { + body = call; + } + break; + } + }; + + try_tcgen05_instruction(getTcgen05Meta_32dp32b()); + try_tcgen05_instruction(getTcgen05Meta_32dp64b()); + try_tcgen05_instruction(getTcgen05Meta_32dp128b()); + try_tcgen05_instruction(getTcgen05Meta_32dp256b()); + + ICHECK(have_succeeded) << "Failed to find a suitable instruction for " + "tcgen05.ld. Check your layout."; + + return body; +} + +/** + * @brief Lower a Copy operator to a bulk TMA (Tensor Memory Accelerator) + * transfer. + * + * Lowers the copy to an optimized TMA load or store when the target and buffer + * layouts permit. Constructs a TMADesc, detects shared-memory + * swizzle/interleave patterns, encodes global shape/stride/SMEM parameters, and + * emits either a 1D TMA transfer (when global/shared are contiguous and element + * counts match, currently only for loads) or a full multi-dimensional TMA call. + * The emitted statement is guarded so only the thread with min thread id + * executes the TMA. + * + * If preconditions are not satisfied (unsupported swizzle, stride/size limits, + * mismatched element counts, OOB risks, or other hardware constraints), this + * function falls back to LowerNormalCopy. + * + * @param T LowerArgs containing target information, thread/bounds variables, + * and layout/ buffer remap information used for descriptor + * construction. + * @param analyzer Analyzer used to prove shapes/contiguity/equality + * constraints. + * @param copy_inst Indicates whether to emit a BulkLoad (TMA load) or BulkStore + * (TMA store). Must be CopyInst::kBulkLoad or kBulkStore. + * @return Stmt A TIR statement performing the bulk TMA copy (or the result of + * LowerNormalCopy when falling back). + */ +Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { + ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) + << "Invalid copy inst " << static_cast(copy_inst); + bool is_load = copy_inst == CopyInst::kBulkLoad; + Buffer global_tensor = is_load ? src : dst; + Buffer shared_tensor = is_load ? dst : src; + Array global_range = is_load ? src_range : dst_range; + Array shared_range = is_load ? 
dst_range : src_range; + // TMA bulk copy cannot support a non-swizzled global layout, will be fallback + // to normal copy + if (T.layout_map.count(global_tensor)) { + LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global " + "layout, fallback to normal copy."; + return LowerNormalCopy(T, analyzer); + } + + // linear layout must be computed before remapping + auto linear_layout = ComputeLinearLayout(shared_tensor); + + Array shared_indices; + for (auto r : shared_range) + shared_indices.push_back(r->min); + std::vector shared_strides; + PrimExpr shared_stride = 1; + for (size_t i = 0; i < shared_tensor->shape.size(); i++) { + auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; + shared_strides.insert(shared_strides.begin(), shared_stride); + shared_stride *= s; + } + + Array global_indices; + for (auto r : global_range) { + global_indices.push_back(r->min); + } + std::vector global_strides; + PrimExpr global_stride = 1; + for (size_t i = 0; i < global_tensor->shape.size(); i++) { + auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; + global_strides.insert(global_strides.begin(), global_stride); + global_stride *= s; + } + + ICHECK(shared_strides.size() == shared_indices.size()) + << "shared_strides.size() != shared_indices.size()" + << shared_strides.size() << " " << shared_indices.size(); + PrimExpr shared_offset = 0; + for (size_t i = 0; i < shared_indices.size(); i++) { + shared_offset += shared_indices[i] * shared_strides[i]; + } + PrimExpr global_offset = 0; + for (size_t i = 0; i < global_indices.size(); i++) { + global_offset += global_indices[i] * global_strides[i]; + } + + TMADesc desc; + // Verify copy rank + desc.rank = global_tensor->shape.size(); + ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank; + + // Verify datatype + ICHECK(global_tensor->dtype == shared_tensor->dtype) + << "Copy between buffer " << global_tensor->name << " and " + << shared_tensor->name << " with different data type " + << global_tensor->dtype << " and " << shared_tensor->dtype; + + desc.data_type = to_CUtensorMapDataType(global_tensor->dtype); + + // Global Tensor Shape and Stride + desc.global_addr = global_tensor->data; + desc.global_shape = ReverseArray(global_tensor->shape); + Array global_coords = + ReverseArray(global_range.Map([](Range r) { return r->min; })); + if (!global_tensor->strides.empty()) { + desc.global_stride = ReverseArray(global_tensor->strides); + } else { + // Create stride from shape + PrimExpr stride = 1; + desc.global_stride.reserve(desc.rank); + for (size_t i = 0; i < desc.rank; i++) { + desc.global_stride.push_back(stride); + stride *= desc.global_shape[i]; + } + } + // The first stride element should be 1 + ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; + // Make global stride in bytes + desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { + return cast(DataType::Int(64), e) * global_tensor->dtype.bytes(); + }); + for (size_t i{1}; i < desc.global_stride.size(); i++) { + auto stride = desc.global_stride[i].as(); + if (stride != nullptr) { + // otherwise, the stride is symbolic, we need to check in future with + // assumptions + if (stride->value % 16 != 0 || stride->value >= (1ULL << 40)) { + LOG(WARNING) << "TMA bulk copy cannot support a global stride of " + << desc.global_stride[i] << ", fallback to normal copy."; + return LowerNormalCopy(T, analyzer); + } + } + } + + // Smem Box + // check smem range and global range is legal + auto s_range_idx = 0; + for (size_t i = 0; i < global_range.size(); 
i++) { + auto g_range = global_range[i]; + if (is_one(g_range->extent)) { + continue; + } + // skip one range if it is 1 + // in case of global range is [128, 64], while shared range is [1, 128, 64] + // A_shared[0, :, :]. + while (is_one(shared_range[s_range_idx]->extent) && + s_range_idx < shared_range.size()) { + s_range_idx++; + } + if (s_range_idx >= shared_range.size()) { + LOG(FATAL) << "TMA bulk copy cannot support a global range of " + << global_range << ", shared_range " << shared_range; + } + auto s_range = shared_range[s_range_idx]; + s_range_idx++; + + ICHECK(StructuralEqual()(g_range->extent, s_range->extent)) + << global_tensor->name << "[" << i << "] is illegal, " + << global_tensor->name << "[" << i << "] = " << g_range->extent << ", " + << shared_tensor->name << "[" << s_range_idx + << "] = " << s_range->extent; + } + // TODO(lei): find a much smarter way to deduce smem box dim + // instead of using global_range + desc.smem_box = + ReverseArray(global_range.Map([](Range r) { return r->extent; })); + + desc.smem_stride = Array(desc.rank, PrimExpr(1)); + // L2 & OOB + desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + // Detect smem layout + // Shared memory swizzling is crucial for TMA performance + // It determines how data is arranged in shared memory banks to minimize bank + // conflicts Different swizzle patterns (32B, 64B, 128B) offer different + // trade-offs between access efficiency and memory usage + desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); + Layout shared_layout; + if (T.layout_map.count(shared_tensor)) { + shared_layout = T.layout_map.at(shared_tensor); + ICHECK(T.buffer_remap.count(shared_tensor)) + << "shared_tensor: " << shared_tensor->name + << " not found in buffer_remap"; + shared_tensor = T.buffer_remap.at(shared_tensor); + } + if (!shared_layout.defined()) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else if (StructuralEqual()(shared_layout, linear_layout)) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else { + ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; + auto stride = as_const_int(shared_layout->InputShape()[0]); + auto continuous = as_const_int(shared_layout->InputShape()[1]); + ICHECK(stride != nullptr && continuous != nullptr); + // We also need to check if the shape satisfies the following doc: + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout( + *stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); + } else if (StructuralEqual()( + shared_layout, + makeHalfBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); + } else if (StructuralEqual()( + shared_layout, + makeFullBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); + } else if (StructuralEqual()( + shared_layout, + makeGemmABLayoutPadded(*stride, *continuous, + shared_tensor->dtype.bits()))) { + LOG(WARNING) << "Bulk copy cannot support a padded layout for src: " + << src->name << ", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } else { + LOG(WARNING) << "Came across unsupported swizzle layout for 
src: " + << src->name << ", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + auto inner_box_dim = as_const_int(desc.smem_box[0]); + if (inner_box_dim == nullptr) { + LOG(WARNING) << "inner_box_dim " << desc.smem_box[0] + << " can only be a constant integer for TMA bulk copy, " + "fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + int instruction_dim = *inner_box_dim; + if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { + instruction_dim = 64 / src->dtype.bytes(); + } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { + instruction_dim = 128 / src->dtype.bytes(); + } + if (instruction_dim > 256) { + // smem_box dim must be in [0, 256] + // if is 512, we need to split the copy into two parts + ICHECK((*inner_box_dim) % 256 == 0) + << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256"; + instruction_dim = 256; + } + ICHECK((*inner_box_dim) % instruction_dim == 0) + << "inner_box_dim: " << *inner_box_dim + << " is not divisible by instruction_dim: " << instruction_dim; + desc.smem_box.Set(0, PrimExpr(instruction_dim)); + + int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes(); + + // Check inner_box_dim_ for each swizzle type in a cleaner way + struct SwizzleCheck { + int swizzle; + int max_dim; + }; + static const std::vector swizzle_checks = { + {static_cast(CU_TENSOR_MAP_SWIZZLE_32B), 32}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_64B), 64}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_128B), 128}, + }; + for (const auto &check : swizzle_checks) { + if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { + LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout " + "with inner_box_dim_ > " + << check.max_dim << ", will be fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + Call create_descriptor = + Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + + Array args; + args.reserve(desc.rank + 4); + args.push_back(create_descriptor); + if (is_load) + args.push_back(0); // mbarrier id placeholder + auto op = is_load ? tma_load() : tma_store(); + + Stmt tma_copy; + PrimExpr total_elements = 1; + for (auto e : desc.smem_box) + total_elements *= e; + + if ((*inner_box_dim) != instruction_dim) { + Var loop_var("i"); + int loop_extent = (*inner_box_dim) / instruction_dim; + + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, + shared_offset + total_elements * loop_var, total_elements); + args.push_back(shared_addr); + global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), op, args))); + } else { + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 
2 : 1, DataType::Handle(), 1, shared_offset, total_elements); + args.push_back(shared_addr); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + } + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + + return tma_copy; +} + +Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { + ICHECK(copy_inst == CopyInst::kBulkLoad1D || + copy_inst == CopyInst::kBulkStore1D); + + // Add 1D TMA copy when the global and shared memory is contiguous + // Check if shared_tensor->name is present in T.buffer_var_gemm + // (Array) to avoid use 1D TMA copy for swizzled layout + bool is_load = copy_inst == CopyInst::kBulkLoad1D; + auto shared_range = is_load ? dst_range : src_range; + auto global_range = is_load ? src_range : dst_range; + auto shared_tensor = is_load ? dst : src; + auto global_tensor = is_load ? src : dst; + + PrimExpr shared_elements = 1; + for (size_t i = 0; i < shared_range.size(); i++) { + shared_elements *= shared_range[i]->extent; + } + + std::vector shared_strides; + PrimExpr shared_stride = 1; + for (size_t i = 0; i < shared_tensor->shape.size(); i++) { + auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; + shared_strides.insert(shared_strides.begin(), shared_stride); + shared_stride *= s; + } + + Array shared_indices; + for (auto r : shared_range) + shared_indices.push_back(r->min); + + Array global_indices; + for (auto r : global_range) { + global_indices.push_back(r->min); + } + std::vector global_strides; + PrimExpr global_stride = 1; + for (size_t i = 0; i < global_tensor->shape.size(); i++) { + auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; + global_strides.insert(global_strides.begin(), global_stride); + global_stride *= s; + } + + PrimExpr global_offset = 0; + for (size_t i = 0; i < global_indices.size(); i++) { + global_offset += global_indices[i] * global_strides[i]; + } + + PrimExpr shared_offset = 0; + for (size_t i = 0; i < shared_indices.size(); i++) { + shared_offset += shared_indices[i] * shared_strides[i]; + } + + PrimExpr elements = analyzer->Simplify(shared_elements); + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); + PrimExpr global_addr = global_tensor.access_ptr( + is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements); + Stmt tma_copy; + if (is_load) { + // the zero is a placeholder for mbarrier ids + tma_copy = Evaluate( + Call(DataType::Handle(), tma_load(), + {shared_addr, global_addr, 0, + elements * shared_tensor->dtype.bytes(), this->eviction_policy})); + } else { + int need_reduce = 0; + tma_copy = Evaluate( + Call(DataType::Handle(), tma_store(), + {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), + need_reduce, this->eviction_policy})); + } + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + return tma_copy; +} +/*! + * \brief Encode the TMA descriptor into an array of PrimExpr. + * This function serializes the TMA descriptor fields into a format suitable for + * passing to the create_tma_descriptor() builtin function. The encoding follows + * the expected argument order for the TMA descriptor creation. + * \return Array of PrimExpr representing the encoded TMA descriptor. 
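+ *
+ * @code
+ * // For a rank-2 descriptor the serialized order is (matching the loops
+ * // below):
+ * //   data_type, rank, global_addr,
+ * //   global_shape[0], global_shape[1],
+ * //   global_stride[0], global_stride[1],
+ * //   smem_box[0], smem_box[1],
+ * //   smem_stride[0], smem_stride[1],
+ * //   interleave, swizzle, l2_promotion, oob_fill
+ * @endcode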
+ */ +Array TMADesc::EncodeCallArgs() const { + Array args; + args.reserve(rank * 4 + 7); + + args.push_back(data_type); + args.push_back(static_cast(rank)); + args.push_back(global_addr); + for (auto e : global_shape) + args.push_back(e); + for (auto e : global_stride) + args.push_back(e); + for (auto e : smem_box) + args.push_back(e); + for (auto e : smem_stride) + args.push_back(e); + args.push_back(interleave); + args.push_back(swizzle); + args.push_back(l2_promotion); + args.push_back(oob_fill); + + return args; +} + +/** + * @brief Construct a Conv2DIm2ColOp node. + * + * Initializes a Conv2DIm2ColOpNode from raw TL-call arguments and a buffer map. + * The constructor extracts source and destination Buffers from vmap and reads + * convolution parameters encoded in args: + * - args[0]: source tensor access pointer + * - args[1]: destination tensor access pointer + * - args[2]: nhw_step (PrimExpr) + * - args[3]: c_step (PrimExpr) + * - args[4]: kernel (IntImm) + * - args[5]: stride (IntImm) + * - args[6]: dilation (IntImm) + * - args[7]: padding (IntImm) + * - args[8]: eviction_policy (IntImm) + * + * The created node stores these values (src, dst, nhw_step, c_step, kernel, + * stride, dilation, padding, eviction_policy) for later lowering to TMA-based + * GPU intrinsics. + * + * @param args Array of PrimExpr TL-call arguments (see list above). + */ +Conv2DIm2ColOp::Conv2DIm2ColOp(Array args) { + ObjectPtr node = + tvm::ffi::make_object(); + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); + node->src_ = node->srcRegion_->buffer; + node->dst_ = node->dstRegion_->buffer; + node->nhw_step_ = args[2]; + node->c_step_ = args[3]; + node->kernel_ = args[4].as().value()->value; + node->stride_ = args[5].as().value()->value; + node->dilation_ = args[6].as().value()->value; + node->padding_ = args[7].as().value()->value; + node->eviction_policy_ = args[8].as().value()->value; + data_ = std::move(node); +} + +/** + * @brief Create a shallow copy of this Conv2DIm2ColOpNode wrapped as a + * TileOperator. + * + * Produces a new Conv2DIm2ColOp that owns a freshly allocated + * Conv2DIm2ColOpNode initialized from this node (member-wise copy). This is + * used to duplicate the operator node for compiler passes that require + * independent operator instances. + * + * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. + */ +TileOperator Conv2DIm2ColOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Conv2DIm2ColOp(op); +} + +/** + * @brief Lower Conv2D im2col into a TMA-backed PTX sequence for Hopper. + * + * Constructs a TMA im2col descriptor from the Conv2DIm2ColOp parameters + * (kernel, stride, dilation, padding, channel/image tiling, dtype and shapes), + * emits a call to create the im2col descriptor, and returns a statement that + * invokes the corresponding tma_load_im2col builtin guarded to a single + * thread. The lowering assumes the destination resides in shared memory and the + * source in global memory and uses the provided layout information (when + * available) to select the appropriate shared-memory swizzle. + * + * Preconditions (checked with ICHECK): + * - Target is Hopper. + * - src.scope() == "global" and dst.scope() is "shared.dyn" or "shared". + * - src->shape has rank 4 and dst->shape has rank 2. + * - src and dst have the same dtype. + * - When a shared layout is supplied it must match a recognized TMA swizzle + * pattern (32B/64B/128B) or an ICHECK will fail. 
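+ *
+ * Example (illustrative): for an NHWC activation `src` of shape [N, H, W, C]
+ * and a KxK kernel, `dst` is one [smem_box_pixel, smem_box_channel] tile of
+ * the implicit im2col matrix of shape [N*OH*OW, K*K*C]; `nhw_step` selects
+ * the pixel-block row and `c_step` selects the kernel*channel-block column of
+ * that matrix.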
+ * + * @param T Lowering context (target, layout map, thread_var, thread_bounds, + * buffer remapping, etc.). Used to fetch target/layout and to emit a + * thread-guarded TMA call. + * @param analyzer Arithmetic analyzer used to prove divisibility and simplify + * expressions required by descriptor construction. + * @return Stmt A TIR statement that performs a tma_load_im2col call wrapped in + * a thread-min guard (IfThenElse). The returned statement is ready + * to be inserted into the lowered TIR. + */ +Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + ICHECK(TargetIsHopper(T.target)); + ICHECK(src_.scope() == "global" && + (dst_.scope() == "shared.dyn" || dst_.scope() == "shared")); + ICHECK(src_->shape.size() == 4); + ICHECK(dst_->shape.size() == 2); + ICHECK(src_->dtype == dst_->dtype); + Layout shared_layout; + if (T.layout_map.count(dst_)) { + shared_layout = T.layout_map[dst_]; + } + + TMAIm2ColDesc desc; + desc.rank = src_->shape.size(); + desc.data_type = to_CUtensorMapDataType(src_->dtype); + desc.global_addr = src_->data; + desc.global_shape = ReverseArray(src_->shape); + + if (!src_->strides.empty()) { + desc.global_stride = ReverseArray(src_->strides); + } else { + // Create stride from shape + PrimExpr stride = 1; + desc.global_stride.reserve(desc.rank); + for (size_t i = 0; i < desc.rank; i++) { + desc.global_stride.push_back(stride); + stride *= desc.global_shape[i]; + } + } + // The first stride element should be 1 + ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; + // Make global stride in bytes + desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { + return cast(DataType::Int(64), e) * src_->dtype.bytes(); + }); + desc.elem_stride = {1, stride_, stride_, 1}; + desc.lower_corner = {-padding_, -padding_}; + desc.upper_corner = {-padding_, -padding_}; + desc.smem_box_pixel = Downcast(dst_->shape[0])->value; + desc.smem_box_channel = Downcast(dst_->shape[1])->value; + desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); + if (!shared_layout.defined()) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else { + ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; + auto stride = as_const_int(shared_layout->InputShape()[0]); + auto continuous = as_const_int(shared_layout->InputShape()[1]); + ICHECK(stride != nullptr && continuous != nullptr); + + if (StructuralEqual()(shared_layout, + makeQuarterBankSwizzleLayout(*stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); + } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( + *stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); + } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( + *stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); + } else { + ICHECK(0) << "Cannot detect TMA layout."; + } + } + + Call create_desc = Call(DataType::Handle(), create_tma_im2col_descriptor(), + desc.EncodeCallArgs()); + + Array global_coords; // c, w, h, n + Array image_offset; // w, h + global_coords.reserve(desc.rank); + + ICHECK(analyzer->CanProveEqual( + FloorMod(desc.global_shape[0], desc.smem_box_channel), 0)) + << "Currently can only support divisible channel case"; + + global_coords.push_back( + FloorMod(c_step_ * 
desc.smem_box_channel, desc.global_shape[0])); + image_offset.push_back( + dilation_ * + FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]), + kernel_)); + image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel, + desc.global_shape[0] * kernel_)); + + PrimExpr h_dim = + FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, + stride_) + + 1; + PrimExpr w_dim = + FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, + stride_) + + 1; + global_coords.push_back( + stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_); + global_coords.push_back( + stride_ * + FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) - + padding_); + global_coords.push_back( + FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim)); + + Array args; + args.reserve(desc.rank * 2 + 2); + args.push_back(create_desc); + args.push_back(0); // mbar placeholder + auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_; + auto shared_addr = dst_buffer.access_ptr(2); + args.push_back(shared_addr); + for (auto coord : global_coords) + args.push_back(coord); + for (auto offset : image_offset) + args.push_back(offset); + args.push_back(this->eviction_policy_); + Stmt tma_copy = + IfThenElse(EQ(T.thread_var, T.thread_bounds->min), + Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); + return tma_copy; +} + +/*! + * \brief Encode the TMA im2col descriptor into an array of PrimExpr. + * This function serializes the TMA im2col descriptor fields for passing to the + * create_tma_im2col_descriptor() builtin function. It includes + * convolution-specific parameters like kernel size, stride, padding, and + * dilation in addition to standard tensor descriptor fields. \return Array of + * PrimExpr representing the encoded TMA im2col descriptor. 
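+ *
+ * @code
+ * // Serialized order for a rank-4 (NHWC) descriptor, matching the loops
+ * // below:
+ * //   data_type, rank, global_addr,
+ * //   global_shape[0..3], global_stride[0..3], elem_stride[0..3],
+ * //   lower_corner[0..1], upper_corner[0..1],
+ * //   smem_box_pixel, smem_box_channel,
+ * //   interleave, swizzle, l2_promotion, oob_fill
+ * @endcode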
+ */ +Array TMAIm2ColDesc::EncodeCallArgs() const { + Array args; + args.reserve(rank * 5 + 5); + + args.push_back(data_type); + args.push_back(static_cast(rank)); + args.push_back(global_addr); + for (auto e : global_shape) + args.push_back(e); + for (auto e : global_stride) + args.push_back(e); + for (auto e : elem_stride) + args.push_back(e); + for (auto e : lower_corner) + args.push_back(e); + for (auto e : upper_corner) + args.push_back(e); + args.push_back(smem_box_pixel); + args.push_back(smem_box_channel); + args.push_back(interleave); + args.push_back(swizzle); + args.push_back(l2_promotion); + args.push_back(oob_fill); + + return args; +} + +void CopyNode::CollectFragmentLayouts(const PrimExpr &expr, + const Map &let_var_to_expr, + const LayoutMap &existing_layouts, + PrimExpr thread_extent, + Range thread_bounds, + Map &result_map) const { + PostOrderVisit(expr, [&](const ObjectRef &node) { + if (auto bl = node.as()) { + if (bl->buffer.scope() == "local.fragment" && + !existing_layouts.count(bl->buffer) && + !result_map.count(bl->buffer)) { + auto f = Fragment::FullyReplicated(bl->buffer->shape, thread_extent); + result_map.Set(bl->buffer, f->BindThreadRange(thread_bounds)); + } + } else if (auto var_node = node.as()) { + auto var = tvm::ffi::GetRef(var_node); + if (let_var_to_expr.count(var)) { + CollectFragmentLayouts(let_var_to_expr[var], let_var_to_expr, + existing_layouts, thread_extent, thread_bounds, + result_map); + } + } + }); +} + +// Register the Copy operation with TVM's TIR system +// This makes the copy operation available for use in TVM programs +// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, +// eviction_policy +// - Marked as opaque since it has side effects (memory writes) +TIR_REGISTER_TL_TILE_OP(Copy, copy) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +/** + * @brief Layout inference hook for Conv2DIm2ColOpNode. + * + * This operator does not provide any layout inference; the function + * intentionally returns an empty LayoutMap to indicate no layout suggestions. + * + * @param T Context for layout inference (ignored). + * @param level Inference level (ignored). + * @return LayoutMap An empty map. + */ +LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + +// Register the Conv2DIm2Col operation with TVM's TIR system +// This operation performs im2col transformation for 2D convolutions using TMA +// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, +// dilation, padding, eviction_policy +// - Marked as opaque since it has side effects (memory writes) +TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col) + .set_num_inputs(9) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + CopyNode::RegisterReflection(); + Conv2DIm2ColOpNode::RegisterReflection(); +} +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/copy.h b/tilelang/original/src/op/copy.h new file mode 100644 index 0000000000000000000000000000000000000000..aca629f5c6e0ed4a8b36ab2d5b646e4648e4aefb --- /dev/null +++ b/tilelang/original/src/op/copy.h @@ -0,0 +1,378 @@ +/*! 
+ * \file tl/op/copy.h + * \brief Copy operations and Tensor Memory Access (TMA) descriptors + */ + +#ifndef TVM_TL_OP_COPY_H_ +#define TVM_TL_OP_COPY_H_ + +#include "operator.h" +#include "parallel.h" + +namespace tvm { +namespace tl { +using namespace tir; + +/// Copy instruction types for different memory access patterns +enum class CopyInst : uint8_t { + kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy + kLDSM = 1, // ldmatrix memory copy + kSTSM = 2, // stmatrix memory copy + kBulkLoad = 3, // utilize tma load + kBulkStore = 4, // utilize tma store + // we should separate the bulk load and store for 1d and multi-dim + // as they have different memory access patterns + kBulkLoad1D = 5, // utilize tma load 1d + kBulkStore1D = 6, // utilize tma store 1d + kTMemLoad = 7, // tcgen05.ld (tensor memory -> register) + kTMemStore = 8, // tcgen05.st (register -> tensor memory) +}; + +/// Descriptor for Tensor Memory Access (TMA) copy operations +struct TMADesc { + size_t rank; ///< Tensor rank (number of dimensions) + int data_type; ///< Data type identifier + Array global_shape; ///< Shape in global memory + Array global_stride; ///< Strides in global memory + Array smem_box; ///< Block shape in shared memory + Array smem_stride; ///< Strides in shared memory + PrimExpr global_addr; ///< Base address in global memory + int swizzle; ///< Memory layout swizzle parameter + int interleave; ///< Memory interleave parameter + int oob_fill; ///< Out-of-bound fill policy + int l2_promotion; ///< L2 cache promotion flag + + /// Encode descriptor fields into runtime call arguments + Array EncodeCallArgs() const; +}; + +/*! + * \brief Descriptor for TMA-based im2col transformation used in Conv2D. + * + * This supports extracting patches from the input image (im2col) + * for convolution lowering, storing them in shared memory. + */ +struct TMAIm2ColDesc { + size_t rank; // Rank of the tensor + int data_type; // Data type identifier + Array global_shape; // Shape of input tensor in global memory + Array global_stride; // Stride in global memory + Array elem_stride; // Stride at element level (per axis) + Array lower_corner; // Lower bound offsets for the extraction window + // (rank - 2 dims) + Array upper_corner; // Upper bound offsets for the extraction window + // (rank - 2 dims) + PrimExpr global_addr; // Base address in global memory + int smem_box_pixel; // Pixel dimension of shared memory box + int smem_box_channel; // Channel dimension of shared memory box + int swizzle; // Memory swizzle setting + int interleave; // Memory interleaving setting + int oob_fill; // Out-of-bound fill policy + int l2_promotion; // Whether to enable L2 cache promotion + + /*! + * \brief Encode descriptor fields into runtime arguments. + */ + Array EncodeCallArgs() const; +}; + +/*! + * \brief Get TVM Op handle for Conv2DIm2Col. + */ + +/*! + * \brief Clone this Conv2DIm2Col operator. + * + * Returns a TileOperator reference that is a shallow clone of this operator. 
+ */ +class CopyNode : public TileOperatorNode { +public: + Buffer src, dst; // Source and destination buffers + Array src_range, dst_range; // Ranges for each dimension in src and dst + IntImm coalesced_width; // Width (in elements) for coalesced memory access + Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + + mutable ParallelOp par_op_; // Optional associated parallelization operator + + enum class EvictionPolicy : uint8_t { + kEvictNormal = 0, + kEvictFirst = 1, + kEvictLast = 2, + }; + + uint8_t eviction_policy; // Policy for cache eviction + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &CopyNode::src) + .def_ro("dst", &CopyNode::dst) + .def_ro("src_range", &CopyNode::src_range) + .def_ro("dst_range", &CopyNode::dst_range) + .def_ro("coalesced_width", &CopyNode::coalesced_width); + } + + /*! + * \brief Lower the copy operator to a TIR statement. + * \param T Arguments for lowering. + * \param analyzer Analyzer for simplification and bounds checks. + */ + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + + /*! + * \brief Infer buffer layouts after applying this operator. + * \param T Arguments for layout inference. + * \param level Level of inference (basic or detailed). + */ + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + /*! + * \brief Check if bulk copy is supported. + */ + bool CheckBulkLoad(Target target, arith::Analyzer *analyzer, + bool check_last_dim = true) const; + + /*! + * \brief Check if bulk store is supported. + */ + bool CheckBulkStore(Target target, arith::Analyzer *analyzer, + bool check_last_dim = true) const; + + /*! + * \brief Check if bulk copy 1d load is supported. + */ + bool CheckBulkLoad1D(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer) const; + + /*! + * \brief Check if bulk copy 1d store is supported. + */ + bool CheckBulkStore1D(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer) const; + + /*! + * \brief Check if bulk copy 1d is supported. + */ + bool CheckBulkCopy1D(const Buffer &global_tensor, const Buffer &shared_tensor, + const Array &global_range, + const Array &shared_range, + const LayoutMap &layout_map, + arith::Analyzer *analyzer) const; + + /*! + * \brief Check if lds memory copy is supported. + */ + bool CheckLDSMCopy(Target target) const; + + /*! + * \brief Check if stsm memory copy is supported. + */ + bool CheckSTSMCopy(Target target) const; + + /*! + * \brief Check if tensor memory load is supported. + */ + bool CheckTMemLoad(Target target) const; + + /*! + * \brief Check if tensor memory store is supported. + */ + bool CheckTMemStore(Target target) const; + + /*! + * \brief Get the copy instruction type. + */ + CopyInst GetCopyInst(Target target, bool disable_tma_lower, + const LayoutMap &layout_map, arith::Analyzer *analyzer, + bool buffer_oob) const; + +protected: + /*! + * \brief Generate lowering for bulk/global-to-shared copy. + */ + Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const; + + /*! + * \brief Generate lowering for bulk copy 1d. + */ + Stmt LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const; + + /*! + * \brief Generate lowering for LDS Memory Copy (shared memory to shared + * memory or smem usage). 
+ */ + Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const; + + /*! + * \brief Generate lowering for tensor memory copy (tcgen05.ld/st/cp). + */ + Stmt LowerTmemCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + + /*! + * \brief Generate lowering for normal copy. + */ + Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + + /*! + * \brief Generate SIMT (thread-level) loop for copying. + */ + For MakeSIMTLoop(arith::Analyzer *analyzer) const; + + /*! + * \brief Compute linear layout for tma copy. + */ + Layout ComputeLinearLayout(const Buffer &shared_tensor) const; + + /*! + * \brief Create iterator variables for multi-dimensional copy loops. + */ + Array MakeIterVars() const; + + /*! + * \brief Calculate source or destination indices from iteration vars. + * \param ivs Iterator variables from MakeIterVars(). + * \param src_dst 0 = make source indices, 1 = make destination indices. + */ + Array MakeIndices(const Array &ivs, int src_dst) const; + + /*! + * \brief Construct the boundary predicate for valid copy (to avoid OOB). + * \param analyzer Arithmetic analyser for simplification. + * \param ivs Iterator variables. + * \param extents Extent expressions for the relevant buffer. + * \param src_dst 0 = predicate for source, 1 = predicate for destination. + */ + PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, + Array extents, int src_dst) const; + + /** + * \brief Create a deep copy of this operator. + * + * Returns a TileOperator that is a copy of the current node, preserving all + * configuration (buffers, parameters, and layout-related fields). + * @return A TileOperator owning the cloned operator node. + */ + + /** + * \brief Constructor. + * \param args Expression arguments for the Conv2D im2col operator. + * \param vmap Buffer variable mapping. + */ + + /** + * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator. + * @return Reference to the singleton TVM Op representing this operator. + */ + TileOperator Clone() const; + +private: + /*! + * \brief Collect fragment buffers from expression and create fully replicated + * layouts. + * + * Recursively searches the expression for BufferLoad nodes with + * "local.fragment" scope, following let bindings. For each found fragment + * buffer, creates a fully replicated layout and adds it to result_map. + * + * \param expr Expression to search. + * \param let_var_to_expr Map from let variables to their bound expressions. + * \param existing_layouts Existing layout map to check for already-inferred + * layouts. \param thread_extent Number of threads for replication. \param + * thread_bounds Thread bounds for binding the layout. \param result_map + * Output map to store collected fragment layouts. + */ + void CollectFragmentLayouts(const PrimExpr &expr, + const Map &let_var_to_expr, + const LayoutMap &existing_layouts, + PrimExpr thread_extent, Range thread_bounds, + Map &result_map) const; +}; + +class Copy : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode); + + /*! + * \brief Constructor. + * \param args Expression arguments for the copy. + * \param vmap Buffer variable mapping. + */ + TVM_DLL Copy(Array args); + + /*! + * \brief Get the TVM Op handle corresponding to this Copy op. + */ + static const Op &Get(); +}; + +/*! + * \brief Special operator for Conv2D im2col transformation. 
+ * + * This operator converts input image layout into columnar format suitable + * for matrix multiplication-based convolution lowering. + */ +class Conv2DIm2ColOpNode : public TileOperatorNode { +public: + BufferRegion srcRegion_, dstRegion_; + Buffer src_, + dst_; // Source (input feature map) and destination (im2col matrix) + int stride_; // Stride for convolution + int padding_; // Padding amount + int dilation_; // Dilation factor + int kernel_; // Kernel size + int eviction_policy_; // Cache eviction policy + PrimExpr nhw_step_; // Step size in NHW dimensions + PrimExpr c_step_; // Step size in channel dimension + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_) + .def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_) + .def_ro("src", &Conv2DIm2ColOpNode::src_) + .def_ro("dst", &Conv2DIm2ColOpNode::dst_) + .def_ro("stride", &Conv2DIm2ColOpNode::stride_) + .def_ro("padding", &Conv2DIm2ColOpNode::padding_) + .def_ro("dilation", &Conv2DIm2ColOpNode::dilation_) + .def_ro("kernel", &Conv2DIm2ColOpNode::kernel_) + .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_); + } + + /*! + * \brief Lower to TIR statement. + */ + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + + /*! + * \brief Infer layout for this operator. + */ + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + /*! + * \brief Get TVM Op handle. + */ + static const Op &Get(); + TileOperator Clone() const; +}; + +class Conv2DIm2ColOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, + Conv2DIm2ColOpNode); + TVM_DLL Conv2DIm2ColOp(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_COPY_H_ diff --git a/tilelang/original/src/op/fill.cc b/tilelang/original/src/op/fill.cc new file mode 100644 index 0000000000000000000000000000000000000000..794b38401e08351ad10faf29f3336a6fe494267f --- /dev/null +++ b/tilelang/original/src/op/fill.cc @@ -0,0 +1,230 @@ +/*! + * \file tl/op/fill.cc + * + * Define elment-wise operators. + */ + +#include "fill.h" + +#include +#include +#include + +#include "../layout/tcgen05_layout.h" +#include "../target/utils.h" +#include "../transform/common/loop_fusion_utils.h" +#include "../transform/common/loop_parallel_transform_utils.h" +#include "../transform/loop_partition.h" +#include "../transform/loop_vectorize.h" +#include "builtin.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a Fill operator node from call arguments and a buffer map. + * + * This constructor builds a FillNode describing an element-wise fill of a + * destination buffer region with a scalar/vector value and stores it in + * `data_`. + * + * Detailed behavior: + * - If `args[0]` is a `BufferLoad`, the loaded buffer becomes the destination + * and the load indices are converted to per-dimension ranges: + * - `Ramp(base, lanes, stride)` is converted to `Range(base, lanes)`. Only + * stride == 1 and constant `lanes` are supported. + * - Non-ramp indices become `Range(index, 1)`. + * - Otherwise `args[0]` is treated as an access pointer; the destination buffer + * is resolved via `vmap[GetVarFromAccessPtr(args[0])]` and the region is the + * full buffer shape for each dimension. 
+ * - `args[1]` is used as the fill value; it is cast to the destination buffer's + * dtype if necessary. + * - Performs validation: + * - Region dimensionality must match destination rank. + * - For statically-known region mins and extents, checks that mins >= 0 and + * extents do not exceed the corresponding destination shape extents. + * + * Parameters: + * @param args Call arguments: expected layout is [dst_access_or_bufferload, + * value]. + * - args[0]: destination access (BufferLoad or pointer expression). + * - args[1]: value to fill (scalar or vector). + * + * Notes: + * - The constructor enforces constraints (e.g., stride == 1 ramps, constant + * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out + * of bounds. + */ +Fill::Fill(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + + BufferRegion region = NormalizeToBufferRegion(args[0]); + node->dst = region->buffer; + node->region = region->region; + + if (args[1]->dtype != node->dst->dtype) { + node->value = Cast(node->dst->dtype, args[1]); + } else { + node->value = args[1]; + } + + ICHECK(node->region.size() == node->dst->shape.size()) + << "region size = " << node->region.size() + << " != " << node->dst->shape.size(); + for (int i = 0; i < node->region.size(); i++) { + // bound check if region is static + if (const auto *min_imm = node->region[i]->min.as()) { + int64_t min = min_imm->value; + ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; + } + if (const auto *extent_imm = node->region[i]->extent.as()) { + // Only perform the upper-bound check when the destination shape + // extent is also statically known. If the shape is symbolic (e.g., Var), + // skip this static check to avoid invalid downcasts. + if (const auto *shape_imm = node->dst->shape[i].as()) { + ICHECK_LE(extent_imm->value, shape_imm->value) + << "region[" << i << "] = " << extent_imm->value << " > " + << node->dst->shape[i]; + } + } + } + data_ = std::move(node); +} + +/** + * @brief Create a copy of this FillNode and return it as a TileOperator. + * + * Constructs a new FillNode by copying the current node and wraps the copy in a + * Fill TileOperator. + * + * @return TileOperator A TileOperator that owns the copied FillNode. + */ +TileOperator FillNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Fill(op); +} + +/** + * @brief Build a SIMT-style nested parallel loop that fills the destination + * buffer. + * + * Constructs per-dimension data-parallel loop iterators matching this node's + * region extents, emits a BufferStore that writes the node's `value` into `dst` + * at the loop indices, and nests the loops (innermost to outermost) as parallel + * `For` nodes. Returns the outermost `For` loop representing the complete + * multi-dimensional fill kernel. + * + * @return For Outermost parallel `For` loop of the generated nested SIMT loop. 
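+ *
+ * As an informal sketch (not the literal TIR produced; the extents, mins,
+ * and loop-variable names below are placeholders), a fill over a 2-D region
+ * generates a nest of the form:
+ *
+ * \code
+ *   for (i, 0, extent0) "parallel":
+ *     for (j, 0, extent1) "parallel":
+ *       dst[min0 + i, min1 + j] = value
+ * \endcode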
+ */ +For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { + int ndim = dst->shape.size(); + Array loop_vars; + Array dst_indices; + for (int i = 0; i < ndim; i++) { + Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype); + loop_vars.push_back({region[i], var, IterVarType::kDataPar}); + // Offset the loop induction variable by region min to honor sliced regions + dst_indices.push_back(region[i]->min + var); + } + Stmt body = BufferStore(dst, value, dst_indices); + for (int i = ndim - 1; i >= 0; i--) { + body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, + ForKind::kParallel, body); + } + return Downcast(body); +} + +/** + * @brief Lower this Fill operator to a TIR statement for the target. + * + * Lowers the FillNode into a Stmt according to the destination buffer scope: + * - "local.fragment" and shared ("shared", "shared.dyn"): create a parallel + * operation from a SIMT loop, infer its layout, partition the root loop by + * the thread variable, vectorize the resulting thread loop, and, if a + * per-thread predicate exists, guard the vectorized loop with that + * predicate. + * - "local": build a SIMT loop and return its vectorized form. + * - other scopes: fatal error. + * + * The lowering may query layout and thread information from @p T and uses the + * provided analyzer for any required arithmetic/layout analysis. + * + * @param T Lowering arguments (target, thread bounds, thread var, layout map). + * @return Stmt The lowered TIR statement implementing the fill. + */ +Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + if (dst.scope() == "local.fragment") { + auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + InferLevel::kFree); + auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, + par_op->GetLoopLayout()); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); + if (par_op->GetPredicate(T.thread_var).defined()) { + return IfThenElse(par_op->GetPredicate(T.thread_var).value(), + vectorized_thread_loop); + } + return vectorized_thread_loop; + } else if (dst.scope() == "local") { + auto init_loop = MakeSIMTLoop(analyzer); + auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer); + return vectorized_thread_loop; + } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" || + dst.scope() == "global") { + auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + InferLevel::kFree); + auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, + par_op->GetLoopLayout()); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); + if (par_op->GetPredicate(T.thread_var).defined()) { + return IfThenElse(par_op->GetPredicate(T.thread_var).value(), + vectorized_thread_loop); + } + return vectorized_thread_loop; + } else { + LOG(FATAL) << "Unsupported scope " << dst.scope(); + return Stmt(); + } +} + +/** + * @brief Infer memory/layout mapping for the Fill operator. + * + * Returns the layout mapping produced by layout inference for this FillNode. + * Currently no layout inference is performed for Fill and the function returns + * an empty LayoutMap. + * + * @param T Context required for layout inference (unused). + * @param level The inference level requested (unused). 
+ * @return LayoutMap Empty map indicating no inferred layouts for this operator. + */ +LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + +TIR_REGISTER_TL_TILE_OP(Fill, fill) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); } + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/fill.h b/tilelang/original/src/op/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..c10a5cfb1b516cf305feb7c58a29c8d20a09e6ae --- /dev/null +++ b/tilelang/original/src/op/fill.h @@ -0,0 +1,55 @@ +/*! + * \file tl/op/fill.h + * \brief Fill operations for tensor initialization + */ + +#ifndef TVM_TL_OP_FILL_H_ +#define TVM_TL_OP_FILL_H_ + +#include "operator.h" +#include "parallel.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/// Node class for fill operations +class FillNode : public TileOperatorNode { +public: + tir::Buffer dst; ///< Destination buffer to fill + PrimExpr value; ///< Value to fill with + Array region; ///< Region to fill within the buffer + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + static const Op &Get(); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dst", &FillNode::dst) + .def_ro("value", &FillNode::value) + .def_ro("region", &FillNode::region); + } + + TileOperator Clone() const; + +private: + /// Create SIMT-style parallel loop for filling + For MakeSIMTLoop(arith::Analyzer *analyzer) const; +}; + +/// Wrapper class for fill operations +class Fill : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); + TVM_DLL Fill(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_FILL_H_ diff --git a/tilelang/original/src/op/finalize_reducer.cc b/tilelang/original/src/op/finalize_reducer.cc new file mode 100644 index 0000000000000000000000000000000000000000..f542b2d917b5b05999f1275cb96b533d73a947b4 --- /dev/null +++ b/tilelang/original/src/op/finalize_reducer.cc @@ -0,0 +1,169 @@ +/*! + * \file src/op/finalize_reducer.cc + * + * Define finalize_reducer operator. + */ + +#include "finalize_reducer.h" + +#include +#include +#include +#include + +#include "../target/utils.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a FinalizeReducerOp from TL operator arguments and a buffer + * map. + * + * Extracts the reducer Buffer from `vmap` using the variable referenced by + * `args[0]` and sets the reduction operation type from the integer code in + * `args[1]`. + * + * @param args TL operator arguments: expects at least two elements where + * `args[0]` is an access pointer identifying the reducer variable + * and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min). + */ +FinalizeReducerOp::FinalizeReducerOp(Array args) { + auto node = tvm::ffi::make_object(); + // Normalize any supported region expression + // (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the + // underlying Buffer as reducer. 
+ auto region = NormalizeToBufferRegion(args[0]); + node->reducer = region->buffer; + node->op = (ReducerOpType)*as_const_int(args[1]); + data_ = std::move(node); +} + +/** + * @brief Lower the finalize_reducer TL operator to a TIR statement. + * + * Lowers the operator that finalizes a reducer by performing a thread-wide + * AllReduce across the reducer's output elements and writing the reduced value + * back into the reducer buffer. The function: + * - Fetches the reducer buffer and expects its layout to be a Fragment. + * - Builds index Vars for each output dimension. + * - Reads the layout's ReplicateExtent and: + * - if extent == 1, emits a no-op Evaluate(0); + * - otherwise constructs an AllReduce extern call (uses `run_hopper` when the + * compilation target is Hopper) with an optional workspace (allocated via + * T.AddWorkspace when reducing_threads >= 32) and stores the result via + * BufferStore. + * - Wraps the store in parallel outer For loops over each output dimension. + * + * @param T Lowering context containing buffer remapping, layout map, thread + * bounds, target, and helper methods (e.g., AddWorkspace). + * @param analyzer Arithmetic analyzer (unused by this implementation but + * provided for consistency with lowering API). + * @return Stmt The lowered TIR statement representing the AllReduce and + * surrounding loops. + * + * @note The function ICHECKs that the reducer layout is present and a Fragment, + * and that ReplicateExtent is either 1 or equal to the thread block + * extent; violations cause a fatal check failure. + */ +Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + auto buffer = T.buffer_remap[reducer]; + auto opt_layout = T.layout_map.Get(reducer); + ICHECK(opt_layout); + ICHECK(opt_layout->as()); + auto layout = opt_layout->as().value(); + Array indices_0; + indices_0.reserve(layout->OutputDim()); + for (int i = 0; i < layout->OutputDim(); ++i) + indices_0.push_back(Var("__finred_" + std::to_string(i))); + + const int64_t *p_extent = as_const_int(layout->ReplicateExtent()); + ICHECK(p_extent); + int extent = *p_extent, scale = 1; + ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent)) + << "Illegal finalize_reducer: extent=" << extent + << "; T.thread_bounds=" << T.thread_bounds; + + if (extent == 1) + return Evaluate(0); + + std::array op_names{"tl::SumOp", "tl::MaxOp", "tl::MinOp"}; + auto op_str = op_names[(int)op]; + + // adopted from ReduceOp + int reducing_threads = extent; + std::stringstream ss; + auto thread_offset = T.thread_bounds->min; + if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) { + auto all_threads = T.thread_bounds->extent; + ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 + << ", " << thread_offset << ", " << all_threads << ">::run_hopper"; + } else { + ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 + << ", " << thread_offset << ">::run"; + } + Array thread_reduce_args = {StringImm(ss.str()), + BufferLoad(buffer, indices_0)}; + if (reducing_threads >= 32) { + PrimExpr workspace = + T.AddWorkspace(*as_const_int(T.thread_bounds->extent), buffer->dtype); + thread_reduce_args.push_back(workspace); + } + auto call = Call(buffer->dtype, builtin::call_extern(), thread_reduce_args); + Stmt body = BufferStore(buffer, call, indices_0); + + // make the outer spatial loop + for (int i = layout->OutputDim() - 1; i >= 0; i--) { + body = For(indices_0[i].as().value(), 0, layout->OutputShape()[i], + ForKind::kParallel, body); + 
} + + return body; +} + +/** + * @brief Infer and return the layout mapping for the reducer buffer. + * + * Copies the existing layout for the reducer from the provided LayoutInferArgs + * into a new LayoutMap and returns it. The inference does not modify the + * layout; it preserves the reducer's current layout. + * + * @param T Provides the input layout map from which the reducer's layout is + * copied. + * @param level Unused by this operator; present for API compatibility. + * @return LayoutMap A map that contains the reducer buffer mapped to its + * original layout. + */ +LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + LayoutMap layout_map; + layout_map.Set(reducer, T.layout_map.Get(reducer).value()); + return layout_map; +} + +/** + * @brief Create a deep copy of this FinalizeReducerOpNode and wrap it as a + * TileOperator. + * + * Constructs a new FinalizeReducerOpNode by copying the current node state and + * returns a TileOperator that owns the copied node. + * + * @return TileOperator A TileOperator that contains a deep copy of this node. + */ +TileOperator FinalizeReducerOpNode::Clone() const { + auto node = tvm::ffi::make_object(*this); + return TileOperator(node); +} + +TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); } +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/finalize_reducer.h b/tilelang/original/src/op/finalize_reducer.h new file mode 100644 index 0000000000000000000000000000000000000000..99e1e7cbfd516880a8c7411575663c3623f94136 --- /dev/null +++ b/tilelang/original/src/op/finalize_reducer.h @@ -0,0 +1,58 @@ +// Copyright (c) Tile-AI Corporation. +// Licensed under the MIT License. + +/*! + * \file src/op/finalize_reducer.h + * \brief Define finalize_reducer operator. + */ + +#ifndef TVM_TL_OP_FINALIZE_REDUCER_H_ +#define TVM_TL_OP_FINALIZE_REDUCER_H_ + +#include "../transform/layout_reducer.h" +#include "./operator.h" + +/** + * Get the Op singleton for the public FinalizeReducerOp handle. + * + * @return A reference to the Op describing FinalizeReducer. 
+ */
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+class FinalizeReducerOpNode : public TileOperatorNode {
+public:
+  tir::Buffer reducer;
+  ReducerOpType op;
+
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp",
+                                    FinalizeReducerOpNode, TileOperatorNode);
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef()
+        .def_ro("reducer", &FinalizeReducerOpNode::reducer)
+        .def_ro("op", &FinalizeReducerOpNode::op);
+  }
+
+  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
+  LayoutMap InferLayout(const LayoutInferArgs &T,
+                        InferLevel level) const override;
+  static const Op &Get();
+  TileOperator Clone() const;
+};
+
+class FinalizeReducerOp : public TileOperator {
+public:
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
+                                             FinalizeReducerOpNode);
+  TVM_DLL FinalizeReducerOp(Array args);
+  static const Op &Get();
+};
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
diff --git a/tilelang/original/src/op/gemm.cc b/tilelang/original/src/op/gemm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7c5058266e910d1287719058342e748b4b91746f
--- /dev/null
+++ b/tilelang/original/src/op/gemm.cc
@@ -0,0 +1,845 @@
+/*!
+ * \file tl/op/gemm.cc
+ * \brief Implementation of General Matrix Multiplication (GEMM) operators
+ */
+
+#include "gemm.h"
+#include "builtin.h"
+#include
+#include
+#include
+#include
+#include
+
+#include "../target/utils.h"
+#include "tcgen5_meta.h"
+#include "utils.h"
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+/**
+ * @brief Construct a Gemm operator from serialized TL arguments and a buffer
+ * map.
+ *
+ * This constructor deserializes operator parameters from `args` and resolves
+ * buffer references via `vmap`, populating an internal GemmNode with:
+ * - device pointers for A, B, C and their corresponding Buffer objects,
+ * - transpose flags for A and B,
+ * - matrix dimensions M, N, K,
+ * - warp allocation policy and clear_accum flag,
+ * - strides and memory offsets for A and B,
+ * - optional kPack (must be 1 or 2) and optional wg_wait.
+ *
+ * The populated GemmNode is stored into the wrapper's internal `data_`.
+ *
+ * @param args Positional serialized arguments produced by the TL frontend:
+ *        expected layout is:
+ *          [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
+ *           M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
+ *           stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
+ *           (optional) kPack (Int), (optional) wg_wait (Int)]
+ *
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ *       fails with an ICHECK (runtime assertion). No other validation is
+ *       performed here.
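+ *
+ * @note Beyond the arguments listed above, the constructor also consumes an
+ *       optional mbarrier operand at args[16] (only when it is a BufferLoad;
+ *       used by the TCGEN5MMA path) and the accumulator coordinates at
+ *       args[17] and args[18], stored as cCoords_ (see the constructor body
+ *       below).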
+ */ +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} + +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} + +Gemm::Gemm(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->bRegion_ = NormalizeToBufferRegion(args[1]); + node->cRegion_ = NormalizeToBufferRegion(args[2]); + + node->a_ = node->aRegion_->buffer; + node->b_ = node->bRegion_->buffer; + node->c_ = node->cRegion_->buffer; + node->transA_ = args[3].as().value(); + node->transB_ = args[4].as().value(); + node->m_ = args[5].as().value()->value; + node->n_ = args[6].as().value()->value; + node->k_ = args[7].as().value()->value; + node->policy_ = GemmWarpPolicy(args[8].as().value()->value); + node->clearAccum_ = args[9].as().value(); + node->strideA_ = args[10].as().value()->value; + node->strideB_ = args[11].as().value()->value; + node->offsetA_ = args[12].as().value()->value; + node->offsetB_ = args[13].as().value()->value; + if (args.size() > 14) { + node->kPack_ = args[14].as().value()->value; + if (node->kPack_ != 1 && node->kPack_ != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 15) { + node->wgWait_ = args[15].as().value()->value; + } + if (args.size() > 16) { + if (const auto *load = args[16].as()) { + node->mbarRegion_ = + NormalizeToBufferRegion(Downcast(args[16])); + node->mbar_ = node->mbarRegion_->buffer; + } else { + node->mbar_ = std::nullopt; + } + } + node->cCoords_ = Array( + {args[17].as().value(), args[18].as().value()}); + data_ = std::move(node); +} + +/** + * @brief Create a copy of this GemmNode as a TileOperator. + * + * Constructs a new GemmNode by copying the current node state and returns it + * wrapped in a Gemm TileOperator. + * + * @return TileOperator A Gemm operator that owns a copy of this node. 
+ */ +TileOperator GemmNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Gemm(op); +} + +bool GemmNode::allowTcgen5Mma(Target target) const { + return TargetIsSm100(target) && + ((a_.scope() == "shared.dyn" || a_.scope() == "shared" || + a_.scope() == "shared.tmem") && + (b_.scope() == "shared.dyn" || b_.scope() == "shared") && + c_.scope() == "shared.tmem") && + GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first; +} + +bool GemmNode::allowWgmma(int block_size, Target target) const { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + + int warp_size = TargetGetWarpSize(target); + int num_warps = block_size / warp_size; + return !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && + TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) && + checkWgmma(); +} + +GemmInst GemmNode::getGemmInst(int block_size, Target target) const { + if (allowTcgen5Mma(target)) { + return GemmInst::kTCGEN5MMA; + } else if (allowWgmma(block_size, target)) { + return GemmInst::kWGMMA; + } else if(TargetIsDCU(target)) { + return GemmInst::KMMAC; + } else if (TargetIsCDNA(target)) { + return GemmInst::kMFMA; + } else if (TargetIsCuda(target)) { + return GemmInst::kMMA; + } else { + ICHECK(0) << "Unsupported target for gemm: " << target; + return GemmInst::kMMA; + } +} + +std::pair GemmWarpPolicyNode::computeWarpPartition( + int M, int N, int block_size, Target target, GemmInst gemm_inst) const { + int num_warps = block_size / TargetGetWarpSize(target); + if (gemm_inst == GemmInst::kTCGEN5MMA) { + return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning + } + + int m_warp = 1, n_warp = 1; + constexpr int kMPerWarp = 16; // Rows processed by a single warp + int kNPerWarp = 8; // Columns processed by a single warp + if (TargetIsVolta(target)) { + kNPerWarp = 16; + } + ICHECK(M % kMPerWarp == 0) + << "M must be divisible by " << kMPerWarp << ", but got " << M; + ICHECK(N % kNPerWarp == 0) + << "N must be divisible by " << kNPerWarp << ", but got " << N; + + if (gemm_inst == GemmInst::kWGMMA) { + ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads."; + + constexpr int kGroup = 4; // Number of warps in a warp-group + + m_warp = kGroup; // Initially, only one warp-group on M dimension + n_warp = num_warps / m_warp; // Rest all on N dimension + + if (this->isFullRow()) { + // Try to put as many warp-groups as possible on M dimension + // (decreasing multiples of 4, ensuring divisibility by M) + for (int cand = num_warps; cand >= kGroup; cand -= kGroup) { + if (M % (cand * kMPerWarp) == 0) { + m_warp = cand; + n_warp = num_warps / m_warp; + break; + } + } + } else if (this->isFullCol()) { + // Try to use warps on N dimension; if N is not divisible, split excess + // groups to M + int cand_n = n_warp; // Initially assume all on N + if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails + int max_n = N / kNPerWarp; + // Find a feasible n_warp from max possible downwards, ensuring + // num_warps/n_warp is multiple of 4 + for (int n = std::min(cand_n, max_n); n >= 1; --n) { + if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) { + n_warp = n; + m_warp = num_warps / n_warp; + break; + } + } + } + } else if (this->isSquare()) { + // Exhaustive search, but m must be multiple of 4 + int max_m = M / kMPerWarp; + int max_n = N / kNPerWarp; + + float ideal = N > 0 ? 
static_cast(M) / N : 1.f; + + float best_score = std::numeric_limits::max(); + int best_m = kGroup, best_n = n_warp; + + for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) { + if (num_warps % m) + continue; + int n = num_warps / m; + if (n > max_n) + continue; + + float m_per_warp = static_cast(M) / (m * kMPerWarp); + float n_per_warp = static_cast(N) / (n * kNPerWarp); + float score = std::abs(m_per_warp / n_per_warp - ideal); + + if (score < best_score) { + best_score = score; + best_m = m; + best_n = n; + } + } + m_warp = best_m; + n_warp = best_n; + } else { + ICHECK(0) << "Unknown GemmWarpPolicy"; + } + + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp + << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + + // Store the computed values in the object's member variables + this->m_warp = m_warp; + this->n_warp = n_warp; + + return {m_warp, n_warp}; + } + + if (this->isFullRow()) { + // Try to partition M first + m_warp = num_warps; + n_warp = 1; + + // If M cannot be evenly divided by m_warp*16, try to split remaining warps + // to N + if (M % (m_warp * kMPerWarp) != 0) { + // Calculate how many warps we can use for M + int max_m_warps = M / kMPerWarp; + m_warp = max_m_warps; + // Use remaining warps for N + n_warp = num_warps / m_warp; + if (n_warp == 0) + n_warp = 1; + } + } else if (this->isFullCol()) { + // Try to partition N first + m_warp = 1; + n_warp = num_warps; + + // If N cannot be evenly divided by n_warp*8, try to split remaining warps + // to M + if (N % (n_warp * kNPerWarp) != 0) { + // Calculate how many warps we can use for N + int max_n_warps = N / kNPerWarp; + n_warp = max_n_warps; + // Use remaining warps for M + m_warp = num_warps / n_warp; + if (m_warp == 0) + m_warp = 1; + } + } else if (this->isSquare()) { + // First calculate the maximum possible warps for each dimension + int max_m_warps = + M / kMPerWarp; // Each warp needs at least 16 elements in M + + // Calculate the ideal ratio of M/N warps based on the matrix dimensions + float ideal_ratio = 1.0f; + if (N > 0) { + ideal_ratio = static_cast(M) / N; + } + + // Try to find the best balanced partition + int best_m = 1; + int best_n = 1; + float best_balance = std::numeric_limits::max(); + // Try all possible combinations that satisfy the constraints + for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { + int n = num_warps / m; + + // Calculate how balanced this partition is + float m_per_warp = static_cast(M) / (m * kMPerWarp); + float n_per_warp = static_cast(N) / (n * kNPerWarp); + // m_per_warp and n_per_warp must be greater than 1 + if (m_per_warp < 1 || n_per_warp < 1) + continue; + // m * n must equal num_warps + if (m * n != num_warps) + continue; + + float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio); + + if (balance < best_balance) { + best_balance = balance; + best_m = m; + best_n = n; + } + } + + m_warp = best_m; + n_warp = best_n; + } else { + ICHECK(0) << "Unknown GemmWarpPolicy"; + } + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp + << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + + // Store the computed values in the object's member variables + this->m_warp = m_warp; + this->n_warp = n_warp; + + return {m_warp, n_warp}; +} + +/** + * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. 
+ * + * Evaluates device-memory placement, data-type combinations, transpose flags, + * and K divisibility constraints required for the Hopper WGMMA code path. + * + * The check returns true only when: + * - B resides in shared memory ("shared" or "shared.dyn"); and + * - (C, A, B) dtypes match one of the supported combinations below and K + * satisfies the required alignment; and + * - for combinations that require specific orientations, A is not transposed + * and B is transposed. + * + * Supported combinations and constraints: + * - C=float16: + * - A=float16, B=float16: K % 16 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 + * - C=float32: + * - A=float16, B=float16: K % 16 == 0 + * - A=bfloat16, B=bfloat16: K % 16 == 0 + * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 + * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 + * - C=int32: + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 + * + * @return true if WGMMA is supported for the current buffers, dtypes, and + * transpose/shape constraints; false otherwise. + */ +bool GemmNode::checkWgmma() const { + if (b_.scope() != "shared.dyn" && b_.scope() != "shared") { + return false; + } + + if (c_->dtype == DataType::Float(16)) { + if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) + return k_ % 16 == 0; + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else if (c_->dtype == DataType::Float(32)) { + if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) + return k_ % 16 == 0; + else if (a_->dtype == DataType::BFloat(16) && + b_->dtype == DataType::BFloat(16)) + return k_ % 16 == 0; + else if (a_->dtype == DataType::Float(32) && + b_->dtype == DataType::Float(32)) + return (!transA_) && transB_ && k_ % 8 == 0; + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else if (c_->dtype == DataType::Int(32)) { + if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else { + return false; + } +} + +/** + * @brief Parse and return the numeric GPU architecture from a Target's "arch" + * attribute. + * + * Examines the target's "arch" string and, if it matches the pattern + * "sm_", returns as an int. If the attribute is present but does not + * match that pattern, returns 0. + * + * Preconditions: the target must have an "arch" attribute (this is checked via + * ICHECK). + * + * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if + * the arch string does not match "sm_". + */ +static int GetArchInt(Target target) { + int arch_int = 0; + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) { + arch_int = std::stoi(arch.substr(3)); + } else { + arch_int = 0; + } + return arch_int; +} + +/** + * @brief Lower the GEMM operator to a TL TIR call expression. 
+ * + * Constructs a tl::gemm call string parameterized by M, N, K, warp partition, + * transpose flags, accumulation clearing, target-specific stride/offset/kPack + * and optional workgroup wait value, then returns an Evaluate(call) node + * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles. + * + * @param T Contains lowering context including thread bounds and target. + * @param analyzer Optional arithmetic analyzer used by lowering (may be + * nullptr). + * @return Stmt A TIR statement representing the evaluated TL GEMM call. + */ +Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = getGemmInst(block_size, T.target); + auto [warp_m, warp_n] = + policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); + + // Build access pointers from regions locally + PrimExpr Aptr = + MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Bptr = + MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Cptr = + MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true); + + std::stringstream ss; + std::string op_name; + + if (gemm_inst == GemmInst::kTCGEN5MMA) { + auto [can_use_tcgen5mma, meta] = + GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype); + ICHECK(can_use_tcgen5mma); + ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared"); + ICHECK(c_.scope() == "shared.tmem"); + ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA"; + if (a_.scope() == "shared.tmem") { + op_name = "tl::tcgen5mma_gemm_ts"; + } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") { + op_name = "tl::tcgen5mma_gemm_ss"; + } else { + ICHECK(0) + << "Unsupported A scope for TCGEN5MMA: " + << a_.scope(); // If this is triggered, it means Tilelang has bugs. + } + ICHECK(wgWait_ == -1) + << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please " + "use " + "wg_wait = -1 and manually synchronize with mbarrier."; + + std::string accum_dtype = ""; + if (c_->dtype.is_float()) { + if (c_->dtype.bits() == 32) { + accum_dtype = "float"; + } + } + ICHECK(!accum_dtype.empty()) + << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype; + ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", "; + ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", "; + ss << transA_ << ", " << transB_ << ", "; + ss << accum_dtype; + ss << ">"; + + auto C_buffer = T.buffer_remap.count(c_) ? 
T.buffer_remap[c_] : c_; + Array new_args; + auto mbarPtr = + MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true); + new_args.push_back(StringImm(ss.str())); + new_args.push_back(Aptr); + new_args.push_back(Bptr); + new_args.push_back(BufferLoad(C_buffer, cCoords_)); + new_args.push_back(mbarPtr); + new_args.push_back(clearAccum_); + auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); + + // Since TCGEN5MMA atoms provided by CUTLASS always have an internal + // `elect_one_sync()`, we check if we are calling it using full warps + constexpr int warp_size = 32; + ICHECK( + analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) && + analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size), + 0)) + << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) " + "and aligned to warps."; + if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) { + // If the thread bounds is exactly one warp, we can use the original call + return Evaluate(new_call); + } else { + // Add an if-else clause + auto tcgen5mma_call = + IfThenElse(EQ(FloorDiv(T.thread_var, warp_size), + FloorDiv(T.thread_bounds->min, warp_size)), + Evaluate(new_call)); + return tcgen5mma_call; + } + } + + if (a_.scope() == "local.fragment") { + ICHECK(b_.scope() != "local.fragment"); + ICHECK(!transA_) + << "gemm_rs requires the A operand to be in non-transposed layout."; + op_name = "tl::gemm_rs"; + } else if (b_.scope() == "local.fragment") { + op_name = "tl::gemm_sr"; + } else { + op_name = "tl::gemm_ss"; + } + ICHECK(c_.scope() == "local.fragment"); + + ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", "; + ss << warp_m << ", " << warp_n << ", "; + ss << transA_ << ", " << transB_; + auto clear_accum_bool = clearAccum_.as(); + ICHECK(clear_accum_bool.has_value()) + << "clear_accum must be a constant Bool type, got " << clearAccum_; + ss << ", " << bool(clear_accum_bool.value()); + if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) { + ss << ", " << strideA_ << ", " << strideB_; + ss << ", " << offsetA_ << ", " << offsetB_; + } + if (TargetIsCDNA(T.target)) { + // for cdna gemm, we need to specify kPack + ss << ", " << kPack_; + } else if (TargetIsHopper(T.target)) { + ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false"); + } + + // Emit wg_wait if necessary + if (TargetIsHopper(T.target)) { + if (wgWait_ != 0) { + ss << ", " << wgWait_; + } + } else if (TargetIsSm100(T.target)) { + // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction + // but all threads need to wait, so we emit another statement for cases + // where wg_wait == 0. + ICHECK(wgWait_ == 0 || wgWait_ == -1) + << "wg_wait must be 0 or -1 for Sm100"; + } else { + ICHECK(wgWait_ == 0) + << "wg_wait must be 0 for non-Hopper and non-Sm100 targets"; + } + ss << ">"; + + auto new_call = Call(DataType::Handle(), tl::tl_gemm(), + Array{StringImm(ss.str()), Aptr, Bptr, Cptr}); + return Evaluate(new_call); +} + +/** + * @brief Infer and bind target-specific memory/layout mappings for A, B, and C. + * + * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM + * operator according to the target architecture, thread bounds, warp + * partitioning, data types, and transpose flags, then binds fragment layouts + * to the thread range when required. + * + * Preconditions: + * - C.scope() == "local.fragment" + * + * Side effects: + * - Marks layout inference as completed (sets completed_ = true). 
+ * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or + * incompatible shape constraints. + * + * @param T Input layout-inference context (provides thread bounds and target). + * @return LayoutMap mapping A, B, and C to their inferred layouts. + */ +LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + auto thread_range = T.thread_bounds; + auto block_size = *as_const_int(thread_range->extent); + GemmInst gemm_inst = getGemmInst(block_size, T.target); + auto [warp_m, warp_n] = + policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); + if (TargetIsVolta(T.target)) { + ICHECK(c_.scope() == "local.fragment") + << "Volta gemm only supports C in local.fragment scope, got " + << c_.scope(); + auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]), + *as_const_int(a_->shape[dim_A - 1]), + true, !transA_)); + } else if (a_.scope() == "local.fragment") { + ICHECK(transA_ == false); + auto fragment = + makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n); + results.Set(a_, fragment->BindThreadRange(thread_range)); + } else { + ICHECK(0); + } + + ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn"); + int dim_B = b_->shape.size(); + results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]), + *as_const_int(b_->shape[dim_B - 1]), + false, transB_)); + } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || + TargetIsSM120(T.target) || + (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) { + ICHECK(c_.scope() == "local.fragment") + << "MMA only supports C in local.fragment scope, got " << c_.scope(); + + auto fragment = + makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); + results.Set(a_, + makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, + a_->dtype.bits(), !transA_)); + } else if (a_.scope() == "local.fragment") { + auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n, + a_->dtype.bits(), transA_); + results.Set(a_, fragment->BindThreadRange(thread_range)); + } else { + ICHECK(0); + } + if (b_.scope() == "shared" || b_.scope() == "shared.dyn") { + int dim_B = b_->shape.size(); + const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); + results.Set(b_, + makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, + b_->dtype.bits(), transB_)); + } else if (b_.scope() == "local.fragment") { + auto fragment = + makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); + results.Set(b_, fragment->BindThreadRange(thread_range)); + } else { + ICHECK(0); + } + } else if (TargetIsHopper(T.target)) { + ICHECK(c_.scope() == "local.fragment") + << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ") + << "only supports C in local.fragment scope, got " << c_.scope(); + auto fragment = gemm_inst == GemmInst::kWGMMA + ? 
makeGemmFragmentCHopper(m_, n_, m_ / warp_m, + n_ / warp_n, c_->dtype.bits()) + : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); + const int64_t continuity = + transA_ ? 4 * mat_continuous / warp_m : mat_continuous; + auto ABLayout = + gemm_inst == GemmInst::kWGMMA + ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, + a_->dtype.bits(), !transA_) + : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, + a_->dtype.bits(), !transA_); + results.Set(a_, ABLayout); + } else { + auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n, + a_->dtype.bits(), transA_); + results.Set(a_, fragment->BindThreadRange(thread_range)); + } + if (b_.scope() == "shared" || b_.scope() == "shared.dyn") { + int dim_B = b_->shape.size(); + const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); + const int64_t continuity = + transB_ ? mat_continuous : mat_continuous / warp_n; + + auto ABLayout = + gemm_inst == GemmInst::kWGMMA + ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, + b_->dtype.bits(), transB_) + : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, + b_->dtype.bits(), transB_); + results.Set(b_, ABLayout); + } else { + auto fragment = + makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); + results.Set(b_, fragment->BindThreadRange(thread_range)); + } + } else if (gemm_inst == GemmInst::kTCGEN5MMA) { + ICHECK(c_.scope() == "shared.tmem") + << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope(); + ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared") + << "Current TCGEN5MMA only supports A in shared.dyn scope"; + auto [can_use_tcgen5mma, meta] = + GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype); + ICHECK(can_use_tcgen5mma); + { + int dim_A = a_->shape.size(); + const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); + results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous, + mat_continuous, a_->dtype.bits(), + transA_ ? 1 : 2)); + } + { + int dim_B = b_->shape.size(); + const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); + const int64_t continuity = mat_continuous; + results.Set(b_, + makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, + b_->dtype.bits(), transB_ ? 
2 : 1)); + } + { + Layout res; + IterVar i = make_itervar("i", m_); + IterVar j = make_itervar("j", n_); + ICHECK(m_ % meta.atom_m == 0); + PrimExpr atom_idx = FloorDiv(i, meta.atom_m) + + FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m); + PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i" + PrimExpr aj = FloorMod(j, meta.atom_n); + if (meta.atom_m == 128) { + // Layout D + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d) + res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n}); + } else if (meta.atom_m == 64) { + // Layout E + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e) + // since .ws variant is used About why we use .ws variant here, please + // refer to gemm_sm100.h + res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) + + FloorDiv(aj, meta.atom_n / 2) * 64, + FloorMod(aj, meta.atom_n / 2) + + atom_idx * (meta.atom_n / 2)}); + } else if (meta.atom_m == 32) { + // Layout G + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g) + res = Layout( + Array{i, j}, + {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32, + FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)}); + } else { + ICHECK(0); + } + results.Set(c_, res); + } + } else if (TargetIsCDNA(T.target)) { + ICHECK(c_.scope() == "local.fragment") + << "CDNA gemm (FMMA) only supports C in local.fragment scope, got " + << c_.scope(); + if (TargetIsDCU(T.target)) { + auto fragment = makeGemmFragmentCDCU(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + } else { + auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + } + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + auto shared_layout = makeGemmABLayoutCDNA( + *as_const_int(a_->shape[dim_A - 2]), + *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_); + results.Set(a_, shared_layout); + } else if (a_.scope() == "local.fragment") { + auto fragment = + makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n, + a_->dtype.bits(), kPack_, transA_); + results.Set(a_, fragment->BindThreadRange(thread_range)); + } else { + ICHECK(0); + } + if (b_.scope() == "shared" || b_.scope() == "shared.dyn") { + int dim_B = b_->shape.size(); + auto shared_layout = makeGemmABLayoutCDNA( + *as_const_int(b_->shape[dim_B - 2]), + *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_); + + results.Set(b_, shared_layout); + } else if (b_.scope() == "local.fragment") { + auto fragment = + makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); + results.Set(b_, fragment->BindThreadRange(thread_range)); + } else { + ICHECK(0); + } + } else { + ICHECK(0) << "Not supported " << T.target->str(); + } + completed_ = true; + return results; +} + +TIR_REGISTER_TL_TILE_OP(Gemm, gemm) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tl.GemmWarpPolicy") + .set_attr("TScriptPrinterName", "GemmWarpPolicy"); + +TVM_FFI_STATIC_INIT_BLOCK() { + GemmNode::RegisterReflection(); + GemmWarpPolicyNode::RegisterReflection(); + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition", + [](GemmWarpPolicy policy, int M, int N, int block_size, + Target target, GemmInst gemm_inst) { + policy->computeWarpPartition(M, N, block_size, 
target, + gemm_inst); + }); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/gemm.h b/tilelang/original/src/op/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..1121e3eb761ea875843e86dbe8db5e7c7b2853a1 --- /dev/null +++ b/tilelang/original/src/op/gemm.h @@ -0,0 +1,154 @@ +/*! + * \file tl/op/gemm.h + * \brief Define gemm operator. + * + */ + +#ifndef TVM_TL_OP_GEMM_H_ +#define TVM_TL_OP_GEMM_H_ + +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +enum class GemmWarpPolicyType : uint8_t { + kSquare = 0, + kFullRow = 1, + kFullCol = 2, + kFree = 3, +}; + +// Target GEMM instruction +enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, KMMAC }; +class GemmWarpPolicyNode : public Object { +public: + mutable int m_warp{0}; + mutable int n_warp{0}; + int policy_type; + + TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("policy_type", &GemmWarpPolicyNode::policy_type) + .def_ro("m_warp", &GemmWarpPolicyNode::m_warp) + .def_ro("n_warp", &GemmWarpPolicyNode::n_warp); + } + + std::pair computeWarpPartition(int M, int N, int block_size, + Target target, + GemmInst gemm_inst) const; + + bool isSquare() const { + return policy_type == int(GemmWarpPolicyType::kSquare); + } + bool isFullRow() const { + return policy_type == int(GemmWarpPolicyType::kFullRow); + } + bool isFullCol() const { + return policy_type == int(GemmWarpPolicyType::kFullCol); + } + bool isFree() const { return policy_type == int(GemmWarpPolicyType::kFree); } +}; + +class GemmWarpPolicy : public ObjectRef { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef, + GemmWarpPolicyNode); + + explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) { + auto node = tvm::ffi::make_object(); + node->policy_type = (int)policy_type; + data_ = std::move(node); + } + + explicit GemmWarpPolicy(int policy_type) { + auto node = tvm::ffi::make_object(); + node->policy_type = policy_type; + data_ = std::move(node); + } + + explicit GemmWarpPolicy(int m_warp, int n_warp) { + auto node = tvm::ffi::make_object(); + node->m_warp = m_warp; + node->n_warp = n_warp; + node->policy_type = (int)GemmWarpPolicyType::kFree; + data_ = std::move(node); + } +}; + +class GemmNode : public TileOperatorNode { +public: + bool checkWgmma() const; + tir::Buffer a_, b_, c_; + // BufferRegion for A, B and C + BufferRegion aRegion_, bRegion_, cRegion_; + bool transA_, transB_; + int m_, n_, k_; + int strideA_, strideB_; + int offsetA_, offsetB_; + PrimExpr clearAccum_ = const_false(); + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack_ = 1; + int wgWait_ = 0; + BufferRegion mbarRegion_; + std::optional mbar_; // mbar is optional, only used for TCGEN5MMA + Array cCoords_; + mutable GemmWarpPolicy policy_; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &GemmNode::a_) + .def_ro("b", &GemmNode::b_) + .def_ro("c", &GemmNode::c_) + .def_ro("aRegion", &GemmNode::aRegion_) + .def_ro("bRegion", &GemmNode::bRegion_) + .def_ro("cRegion", &GemmNode::cRegion_) + .def_ro("transA", &GemmNode::transA_) + .def_ro("transB", &GemmNode::transB_) + .def_ro("m", &GemmNode::m_) + .def_ro("n", &GemmNode::n_) 
+ .def_ro("k", &GemmNode::k_) + .def_ro("strideA", &GemmNode::strideA_) + .def_ro("strideB", &GemmNode::strideB_) + .def_ro("offsetA", &GemmNode::offsetA_) + .def_ro("offsetB", &GemmNode::offsetB_) + .def_ro("clearAccum", &GemmNode::clearAccum_) + .def_ro("kPack", &GemmNode::kPack_) + .def_ro("wgWait", &GemmNode::wgWait_) + .def_ro("policy", &GemmNode::policy_); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + +private: + GemmInst getGemmInst(int block_size, Target target) const; + bool allowTcgen5Mma(Target target) const; + bool allowWgmma(int block_size, Target target) const; + + mutable bool completed_ = false; +}; + +class Gemm : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); + TVM_DLL Gemm(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_H_ diff --git a/tilelang/original/src/op/gemm_py.cc b/tilelang/original/src/op/gemm_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..f73922afad945c16cf08243b3313a3aac0cf77e2 --- /dev/null +++ b/tilelang/original/src/op/gemm_py.cc @@ -0,0 +1,355 @@ +/*! + * \file tl/op/gemm_py.cc + * \brief Implementation of General Matrix Multiplication (GEMM) operators + */ + +#include "gemm_py.h" + +#include "builtin.h" +#include +#include +#include +#include + +#include "../target/utils.h" +#include "tcgen5_meta.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} + +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} + +/** + * @brief Construct a Gemm operator from serialized TL arguments and a buffer + * map. + * + * This constructor deserializes operator parameters from `args` and resolves + * buffer references via `vmap`, populating an internal GemmPyNode with: + * - device pointers for A, B, C and their corresponding Buffer objects, + * - transpose flags for A and B, + * - matrix dimensions M, N, K, + * - warp allocation policy and clear_accum flag, + * - strides and memory offsets for A and B, + * - optional kPack (must be 1 or 2) and optional wg_wait. + * + * The populated GemmPyNode is stored into the wrapper's internal `data_`. + * + * @param args Positional serialized arguments produced by the TL frontend: + * expected layout is: + * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), + * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), + * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), + * (optional) kPack (Int), (optional) wg_wait (Int)] + * + * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * fails with an ICHECK (runtime assertion). No other validation is + * performed here. 
+ */ +GemmPy::GemmPy(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->bRegion_ = NormalizeToBufferRegion(args[1]); + node->cRegion_ = NormalizeToBufferRegion(args[2]); + + node->a_ = node->aRegion_->buffer; + node->b_ = node->bRegion_->buffer; + node->c_ = node->cRegion_->buffer; + node->transA_ = args[3].as().value(); + node->transB_ = args[4].as().value(); + node->m_ = args[5].as().value()->value; + node->n_ = args[6].as().value()->value; + node->k_ = args[7].as().value()->value; + node->policy_ = GemmWarpPolicy(args[8].as().value()->value); + node->clearAccum_ = args[9].as().value(); + node->strideA_ = args[10].as().value()->value; + node->strideB_ = args[11].as().value()->value; + node->offsetA_ = args[12].as().value()->value; + node->offsetB_ = args[13].as().value()->value; + if (args.size() > 14) { + node->kPack_ = args[14].as().value()->value; + if (node->kPack_ != 1 && node->kPack_ != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 15) { + node->wgWait_ = args[15].as().value()->value; + } + if (args.size() > 16) { + if (const auto *load = args[16].as()) { + node->mbarRegion_ = + NormalizeToBufferRegion(Downcast(args[16])); + node->mbar_ = node->mbarRegion_->buffer; + } + } + node->cCoords_ = Array( + {args[17].as().value(), args[18].as().value()}); + data_ = std::move(node); +} + +/** + * @brief Create a copy of this GemmPyNode as a TileOperator. + * + * Constructs a new GemmPyNode by copying the current node state and returns it + * wrapped in a Gemm TileOperator. + * + * @return TileOperator A Gemm operator that owns a copy of this node. + */ +TileOperator GemmPyNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return GemmPy(op); +} + +bool GemmPyNode::allowTcgen5Mma(Target target) const { + return TargetIsSm100(target) && + ((a_.scope() == "shared.dyn" || a_.scope() == "shared" || + a_.scope() == "shared.tmem") && + (b_.scope() == "shared.dyn" || b_.scope() == "shared") && + c_.scope() == "shared.tmem") && + GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first; +} + +bool GemmPyNode::allowWgmma(int block_size, Target target) const { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + + int warp_size = TargetGetWarpSize(target); + int num_warps = block_size / warp_size; + return !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && + TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) && + checkWgmma(); +} + +GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const { + bool allow_tcgen5mma = allowTcgen5Mma(target); + bool allow_wgmma = allowWgmma(block_size, target); + if (allow_tcgen5mma) { + return GemmInst::kTCGEN5MMA; + } else if (allow_wgmma) { + return GemmInst::kWGMMA; + } else if(TargetIsDCU(target)) { + return GemmInst::KMMAC; + } else if (TargetIsCDNA(target)) { + return GemmInst::kMFMA; + } else if (TargetIsVolta(target) || TargetIsAmpere(target) || + TargetIsTuring(target) || TargetIsHopper(target) || + TargetIsSm100(target) || TargetIsSM120(target)) { + return GemmInst::kMMA; + } else { + ICHECK(0) << "Unsupported target for gemm: " << target->str(); + return GemmInst::kMMA; // This line will never be reached due to ICHECK, but + // satisfies compiler + } +} + +/** + * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. 
+ * + * Evaluates device-memory placement, data-type combinations, transpose flags, + * and K divisibility constraints required for the Hopper WGMMA code path. + * + * The check returns true only when: + * - B resides in shared memory ("shared" or "shared.dyn"); and + * - (C, A, B) dtypes match one of the supported combinations below and K + * satisfies the required alignment; and + * - for combinations that require specific orientations, A is not transposed + * and B is transposed. + * + * Supported combinations and constraints: + * - C=float16: + * - A=float16, B=float16: K % 16 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 + * - C=float32: + * - A=float16, B=float16: K % 16 == 0 + * - A=bfloat16, B=bfloat16: K % 16 == 0 + * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 + * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 + * - C=int32: + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 + * + * @return true if WGMMA is supported for the current buffers, dtypes, and + * transpose/shape constraints; false otherwise. + */ +bool GemmPyNode::checkWgmma() const { + if (b_.scope() != "shared.dyn" && b_.scope() != "shared") { + return false; + } + + if (c_->dtype == DataType::Float(16)) { + if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) + return k_ % 16 == 0; + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else if (c_->dtype == DataType::Float(32)) { + if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) + return k_ % 16 == 0; + else if (a_->dtype == DataType::BFloat(16) && + b_->dtype == DataType::BFloat(16)) + return k_ % 16 == 0; + else if (a_->dtype == DataType::Float(32) && + b_->dtype == DataType::Float(32)) + return (!transA_) && transB_ && k_ % 8 == 0; + else if (a_->dtype.is_float8() && b_->dtype.is_float8()) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else if (c_->dtype == DataType::Int(32)) { + if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8)) + return (!transA_) && transB_ && k_ % 32 == 0; + else + return false; + } else { + return false; + } +} + +/** + * @brief Parse and return the numeric GPU architecture from a Target's "arch" + * attribute. + * + * Examines the target's "arch" string and, if it matches the pattern + * "sm_", returns as an int. If the attribute is present but does not + * match that pattern, returns 0. + * + * Preconditions: the target must have an "arch" attribute (this is checked via + * ICHECK). + * + * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if + * the arch string does not match "sm_". 
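+ *
+ * For example, "sm_80" yields 80, and "sm_90a" yields 90 because std::stoi
+ * stops at the first non-digit character of the suffix.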
+ */ +static int GetArchInt(Target target) { + int arch_int = 0; + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) { + arch_int = std::stoi(arch.substr(3)); + } else { + arch_int = 0; + } + return arch_int; +} + +Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = getGemmInst(block_size, T.target); + + auto [warp_m, warp_n] = + policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { + auto prim_func = + Downcast((*f)(tvm::ffi::GetRef(this), T.layout_map, + T.target, T.thread_bounds, T.thread_var)); + ICHECK(prim_func->attrs.defined()); + auto global_symbol = + prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.has_value()); + if (prim_func->body.as()) { + BlockRealize block_realize = Downcast(prim_func->body); + auto block = block_realize->block; + { + BlockNode *n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, + block); + } + // warp with block realize node + return BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body)); + } else { + LOG(FATAL) << "No lower function found for gemm_py"; + return Stmt(); // This line will never be reached due to LOG(FATAL), but + // satisfies compiler + } +} + +LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { + results = Downcast( + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); + // Bind all fragment layouts with the provided thread range + for (auto kv : results) { + const Buffer &buf = kv.first; + const Layout &layout = kv.second; + if (auto frag = layout.as()) { + results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds)); + } + } + } else { + LOG(FATAL) << "No infer layout function found for gemm_py"; + } + + completed_ = true; + return results; +} + +TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); } + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.GemmPyGemmInst", + [](GemmPy gemm_py, int block_size, Target target) { + return gemm_py->getGemmInst(block_size, target); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tl.get_tcgen5_mma_meta", + [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) { + auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype); + Array result; + if (success) { + result.push_back(Integer(meta.atom_m)); + result.push_back(Integer(meta.atom_n)); + result.push_back(Integer(meta.atom_k)); + result.push_back(Integer(meta.enable_ws)); + result.push_back(Integer(meta.enable_2cta)); + } + return result; + }); + refl::GlobalDef().def( + "tl.get_tcgen5_instr_desc", + [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, + int scale_in_b) { + uint32_t desc = GetTCGEN5InstrDesc(atom_m, 
atom_n, atom_k, ab_dtype, + c_dtype, a_is_k_major, b_is_k_major, + scale_in_a, scale_in_b); + return Integer(static_cast(desc)); + }); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/gemm_py.h b/tilelang/original/src/op/gemm_py.h new file mode 100644 index 0000000000000000000000000000000000000000..2fe47be881ff963cabf3d069ba9457dd2b5330fd --- /dev/null +++ b/tilelang/original/src/op/gemm_py.h @@ -0,0 +1,93 @@ +/*! + * \file tl/op/gemm_py.h + * \brief Define gemm operator. + * + */ + +#ifndef TVM_TL_OP_GEMM_PY_H_ +#define TVM_TL_OP_GEMM_PY_H_ + +#include "gemm.h" +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +class GemmPyNode : public TileOperatorNode { +public: + bool checkWgmma() const; + bool allowTcgen5Mma(Target target) const; + bool allowWgmma(int block_size, Target target) const; + tir::Buffer a_, b_, c_; + // BufferRegion for A, B and C + BufferRegion aRegion_, bRegion_, cRegion_; + bool transA_, transB_; + int m_, n_, k_; + int strideA_, strideB_; + int offsetA_, offsetB_; + PrimExpr clearAccum_ = const_false(); + BufferRegion mbarRegion_; + tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA + Array cCoords_; + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack_ = 1; + int wgWait_ = 0; + mutable GemmWarpPolicy policy_; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &GemmPyNode::a_) + .def_ro("b", &GemmPyNode::b_) + .def_ro("c", &GemmPyNode::c_) + .def_ro("aRegion", &GemmPyNode::aRegion_) + .def_ro("bRegion", &GemmPyNode::bRegion_) + .def_ro("cRegion", &GemmPyNode::cRegion_) + .def_ro("transA", &GemmPyNode::transA_) + .def_ro("transB", &GemmPyNode::transB_) + .def_ro("m", &GemmPyNode::m_) + .def_ro("n", &GemmPyNode::n_) + .def_ro("k", &GemmPyNode::k_) + .def_ro("strideA", &GemmPyNode::strideA_) + .def_ro("strideB", &GemmPyNode::strideB_) + .def_ro("offsetA", &GemmPyNode::offsetA_) + .def_ro("offsetB", &GemmPyNode::offsetB_) + .def_ro("clearAccum", &GemmPyNode::clearAccum_) + .def_ro("mbarRegion", &GemmPyNode::mbarRegion_) + .def_ro("mbar", &GemmPyNode::mbar_) + .def_ro("cCoords", &GemmPyNode::cCoords_) + .def_ro("kPack", &GemmPyNode::kPack_) + .def_ro("wgWait", &GemmPyNode::wgWait_) + .def_ro("policy", &GemmPyNode::policy_); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + + // Target GEMM instruction + GemmInst getGemmInst(int block_size, Target target) const; + +private: + mutable bool completed_ = false; +}; + +class GemmPy : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); + TVM_DLL GemmPy(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_PY_H_ diff --git a/tilelang/original/src/op/gemm_sp.cc b/tilelang/original/src/op/gemm_sp.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c0ae08b9ef62e8801f9e47b2979ff9bafe31007 --- /dev/null +++ b/tilelang/original/src/op/gemm_sp.cc @@ -0,0 +1,326 @@ +/*! + * \file tl/op/gemm_sp.cc + * + * Define gemm_sp operator. 
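+ *
+ * gemm_sp performs a sparse GEMM: in addition to the usual A, B and C
+ * operands it takes a sparsity-metadata buffer E, and lowering emits a call
+ * to the tl::gemm_sp_ss device template with A, B and E staged in shared
+ * memory.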
+ */ + +#include "gemm_sp.h" + +#include +#include +#include +#include + +#include "../target/utils.h" +#include "builtin.h" +#include "gemm.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +std::pair GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, + int block_size, + Target target, + bool use_wgmma, + int bits) const { + int num_warps = block_size / TargetGetWarpSize(target); + + auto [m_warp, n_warp] = GemmWarpPolicyNode::computeWarpPartition( + M, N, block_size, target, use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA); + + // Special handling for gemm_sp when the tiling size is not a multiple + // This should be consistent with shape check in gemm_sp_sm80.h + int m_atom_size = bits == 16 ? 32 : 16; + int n_atom_size = bits == 16 ? 32 : 16; + static const char *err_msg = + "Cannot arrange the warp shape to be a multiple of atom size, please " + "reduce num threads or increase tiling size"; + if (TargetIsAmpere(target)) { + int warp_shape_m = M / m_warp; + int warp_shape_n = N / n_warp; + if (warp_shape_m % m_atom_size) { // GemmWarpPolicy::kFullRow + m_warp = M / m_atom_size; + ICHECK(m_warp > 0) << err_msg; + n_warp = num_warps / m_warp; + warp_shape_n = N / n_warp; + ICHECK(warp_shape_n % n_atom_size == 0) << err_msg; + } else if (warp_shape_n % n_atom_size != 0) { // GemmWarpPolicy::kFullColumn + n_warp = N / n_atom_size; + ICHECK(n_warp > 0) << err_msg; + m_warp = num_warps / n_warp; + warp_shape_m = M / m_warp; + ICHECK(warp_shape_m % m_atom_size == 0) << err_msg; + } + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps, please report an issue when " + "encounter this" + << ", m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps" + << num_warps; + this->m_warp = m_warp; + this->n_warp = n_warp; + } + return {m_warp, n_warp}; +} + +/** + * @brief Construct a GemmSP operator node from TL call arguments and a buffer + * map. + * + * Parses the expected call argument tuple and fills an internal GemmSPNode: + * - Buffers: A (args[0]), E (args[1]), B (args[2]), C (args[3]) are looked up + * in vmap. + * - Booleans: trans_A (args[4]), trans_B (args[5]). + * - Dimensions: M (args[6]), N (args[7]), K (args[8]) as integers. + * - Warp policy: policy (args[9]) mapped to GemmWarpPolicy. + * - clear_accum: boolean flag (args[10]). + * - Optional kPack (args[11]): must be 1 or 2 (checked via ICHECK). + * - Optional wg_wait (args[12]): integer workgroup wait parameter. + * + * The populated GemmSPNode is stored in the instance's internal data_ pointer. + * + * @param args Positional TL call arguments in the above order. + * + * @note An ICHECK failure is raised if a provided kPack is not 1 or 2. 
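+ *
+ * (In the current signature the buffers are recovered from the region
+ * arguments via NormalizeToBufferRegion; there is no separate buffer map.)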
+ */ +GemmSP::GemmSP(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->eRegion_ = NormalizeToBufferRegion(args[1]); + node->bRegion_ = NormalizeToBufferRegion(args[2]); + node->cRegion_ = NormalizeToBufferRegion(args[3]); + node->a_ = node->aRegion_->buffer; + node->e_ = node->eRegion_->buffer; + node->b_ = node->bRegion_->buffer; + node->c_ = node->cRegion_->buffer; + node->transA_ = args[4].as().value(); + node->transB_ = args[5].as().value(); + node->m_ = args[6].as().value()->value; + node->n_ = args[7].as().value()->value; + node->k_ = args[8].as().value()->value; + node->policy_ = GemmSPWarpPolicy(args[9].as().value()->value); + node->clearAccum_ = args[10].as().value(); + if (args.size() > 11) { + node->kPack_ = args[11].as().value()->value; + if (node->kPack_ != 1 && node->kPack_ != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 12) { + node->wgWait_ = args[12].as().value()->value; + } + data_ = std::move(node); +} + +/** + * @brief Create a deep copy of this GemmSPNode wrapped as a TileOperator. + * + * Returns a new TileOperator that owns a copy of this node. The cloned node + * duplicates all fields of the original; subsequent modifications to the + * clone do not affect the original node. + * + * @return TileOperator A TileOperator holding a cloned GemmSPNode. + */ +TileOperator GemmSPNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return GemmSP(op); +} + +/** + * @brief Lower this GemmSP node to a TL (tensile-like) intrinsic call. + * + * Constructs and returns an Evaluate statement containing a call to the + * TL gemm_sp intrinsic that encodes this GEMM's template parameters + * (M, N, K, warp partition, transposition flags, clear_accum, and optional + * Hopper/WGMMA and wg_wait modifiers) and the remapped buffer access pointers. + * + * The function validates that A, B, and E reside in shared (or shared.dyn) + * memory (ICHECK failures otherwise), computes the warp partition based on + * the launch configuration and target, and emits a single tl::tl_gemm_sp call + * with a string template describing the configuration. + * + * @param T Lowering context containing thread bounds, target, and optional + * buffer remapping used to obtain the final buffer AccessPtr + * arguments for the TL call. + * @return Stmt An Evaluate wrapping the constructed tl::tl_gemm_sp call. 
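+ *
+ * Illustrative shape of the emitted call (names as in the code below; the
+ * use_wgmma flag is appended only on Hopper and wg_wait only when non-zero):
+ *   tl::gemm_sp_ss<M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum
+ *                  [, use_wgmma][, wg_wait]>(A_ptr, B_ptr, C_ptr, E_ptr);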
+ */ +Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + int warp_size = 32; + + auto block_size = *as_const_int(T.thread_bounds->extent); + bool maybe_wgmma = TargetIsHopper(T.target) && (this->m_ >= 64) && + (block_size / warp_size % 4 == 0); + + auto [warp_m, warp_n] = policy_->computeWarpPartition( + m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits()); + + std::stringstream ss; + std::string op_name = "tl::gemm_sp_ss"; + ICHECK((a_.scope() == "shared" || a_.scope() == "shared.dyn") && + (b_.scope() == "shared" || b_.scope() == "shared.dyn")) + << "Only support shared.dyn scope for A and B, but received " + << a_.scope() << " and " << b_.scope(); + ICHECK((e_.scope() == "shared" || e_.scope() == "shared.dyn")) + << "Only support shared.dyn scope for E as copy from smem to rmem are " + "delegated to cute implementation, found " + << e_.scope(); + ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", "; + ss << warp_m << ", " << warp_n << ", "; + ss << transA_ << ", " << transB_; + ss << ", " << clearAccum_; + if (TargetIsHopper(T.target)) { + ss << ", " << (maybe_wgmma ? "true" : "false"); + } + if (wgWait_ != 0) { + ss << ", " << wgWait_; + } + ss << ">"; + auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_; + auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_; + auto C_buffer = T.buffer_remap[c_]; + auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_; + + auto new_call = + Call(DataType::Handle(), tl::tl_gemm_sp(), + Array{StringImm(ss.str()), A_buffer.access_ptr(1), + B_buffer.access_ptr(1), C_buffer.access_ptr(3), + E_buffer.access_ptr(1)}); + return Evaluate(new_call); +} + +/** + * @brief Infers and returns the memory/layout mapping for the GemmSP operator. + * + * Infers thread-local fragment layout for C and shared-memory layouts for A and + * B based on the target (Hopper-only path), block/thread bounds in T, + * transposition flags, and matrix dimensions stored in the node. The function + * caches its work: if layout inference has already completed (completed_ == + * true) it returns an empty LayoutMap. + * + * Precondition: + * - C.scope() must be "local.fragment". + * + * Behavior notes: + * - Only the Hopper target is supported; non-Hopper targets trigger a fatal + * check. + * - For Hopper, the function computes a warp partition from block size and may + * enable WGMMA-specific fragment creation when conditions on M and block size + * are met. + * - A and B must reside in "shared" or "shared.dyn"; otherwise the function + * aborts with a check failure. + * - The method sets completed_ = true before returning to avoid re-entrance. + * + * @param T LayoutInferArgs containing thread bounds and the target (used to + * select Hopper-specific layouts). + * @param level Currently unused inference detail level. + * @return LayoutMap mapping A, B, and C to their inferred layouts (or empty if + * inference was already completed). 
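+ *
+ * (Note: an Ampere path is also implemented below; the fatal check fires only
+ * for targets other than Hopper and Ampere.)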
+ */ +LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + ICHECK(c_.scope() == "local.fragment"); + auto thread_range = T.thread_bounds; + auto block_size = *as_const_int(thread_range->extent); + if (TargetIsHopper(T.target)) { + const int warp_size = 32; + constexpr int wgmma_m = 16 * 4; + bool maybe_wgmma = + (this->m_ >= wgmma_m) && (block_size / warp_size % 4 == 0); + auto [warp_m, warp_n] = policy_->computeWarpPartition( + m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits()); + auto fragment = maybe_wgmma + ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m, + n_ / warp_n, c_->dtype.bits()) + : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); + results.Set(a_, makeGemmABLayoutHopper(mat_stride, mat_continuous, + mat_continuous, a_->dtype.bits(), + transA_ ? 1 : 2)); + } else { + ICHECK(false) << "Not implemented"; + } + + if (b_.scope() == "shared" || b_.scope() == "shared.dyn") { + int dim_B = b_->shape.size(); + const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); + const int64_t continuity = + transB_ ? mat_continuous : mat_continuous / warp_n; + results.Set(b_, + makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, + b_->dtype.bits(), transB_ ? 2 : 1)); + } else { + ICHECK(false) << "WGMMA only support B in shared."; + } + } else if (TargetIsAmpere(T.target)) { + auto [warp_m, warp_n] = policy_->computeWarpPartition( + m_, n_, block_size, T.target, false, a_->dtype.bits()); + auto fragment = makeGemmSparseFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, + c_->dtype.bits()); + results.Set(c_, fragment->BindThreadRange(thread_range)); + + if (a_.scope() == "shared" || a_.scope() == "shared.dyn") { + int dim_A = a_->shape.size(); + const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); + results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous, + a_->dtype.bits())); + } else if (a_.scope() == "local.fragment") { + // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, + // A->dtype.bits(), trans_A); + // results.Set(A, fragment->BindThreadRange(thread_range)); + ICHECK(false) << "Not Implemented"; + } else { + ICHECK(0); + } + if (b_.scope() == "shared" || b_.scope() == "shared.dyn") { + int dim_B = b_->shape.size(); + const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); + results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous, + b_->dtype.bits())); + } else if (b_.scope() == "local.fragment") { + // auto fragment = + // makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); + // results.Set(B, fragment->BindThreadRange(thread_range)); + ICHECK(false) << "Not Implemented"; + } else { + ICHECK(0); + } + } else { + ICHECK(0) << "Architecture is not supported: " << T.target->str(); + } + completed_ = true; + return results; +} + +TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + 
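+// TIR_REGISTER_TL_TILE_OP (defined in operator.h) registers the
+// "tl.tileop.gemm_sp" op and installs a TLOpBuilder that forwards the call
+// arguments to the GemmSP constructor above.
+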
+TVM_REGISTER_OP("tl.GemmSPWarpPolicy") + .set_attr("TScriptPrinterName", "GemmSPWarpPolicy"); + +TVM_FFI_STATIC_INIT_BLOCK() { + GemmSPNode::RegisterReflection(); + GemmSPWarpPolicyNode::RegisterReflection(); + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tl.GemmSPWarpPolicyComputeWarpPartition", + [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target, + bool use_wgmma, int bits) { + policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits); + return; + }); +} +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/gemm_sp.h b/tilelang/original/src/op/gemm_sp.h new file mode 100644 index 0000000000000000000000000000000000000000..a634e922fed37f98d75a3c541f3ad17e420407df --- /dev/null +++ b/tilelang/original/src/op/gemm_sp.h @@ -0,0 +1,119 @@ +/*! + * \file tl/op/gemm_sp.h + * \brief Define gemm_sp operator. + * + */ + +#ifndef TVM_TL_OP_GEMM_SP_H_ +#define TVM_TL_OP_GEMM_SP_H_ + +#include "gemm.h" +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +class GemmSPWarpPolicyNode : public GemmWarpPolicyNode { +public: + std::pair computeWarpPartition(int M, int N, int block_size, + Target target, bool use_wgmma, + int bits) const; + TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, + GemmWarpPolicyNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type) + .def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp) + .def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp); + } +}; + +class GemmSPWarpPolicy : public ObjectRef { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef, + GemmSPWarpPolicyNode); + + explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) { + auto node = tvm::ffi::make_object(); + node->policy_type = (int)policy_type; + data_ = std::move(node); + } + + explicit GemmSPWarpPolicy(int policy_type) { + auto node = tvm::ffi::make_object(); + node->policy_type = policy_type; + data_ = std::move(node); + } + + explicit GemmSPWarpPolicy(int m_warp, int n_warp) { + auto node = tvm::ffi::make_object(); + node->m_warp = m_warp; + node->n_warp = n_warp; + node->policy_type = (int)GemmWarpPolicyType::kFree; + data_ = std::move(node); + } +}; + +class GemmSPNode : public TileOperatorNode { +public: + BufferRegion aRegion_, bRegion_, cRegion_, eRegion_; + tir::Buffer a_, b_, c_, e_; + bool transA_, transB_; + int m_, n_, k_; + bool clearAccum_ = false; + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack_ = 1; + int wgWait_ = 0; + + mutable GemmSPWarpPolicy policy_; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode); + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("policy", &GemmSPNode::policy_) + .def_ro("aRegion", &GemmSPNode::aRegion_) + .def_ro("bRegion", &GemmSPNode::bRegion_) + .def_ro("cRegion", &GemmSPNode::cRegion_) + .def_ro("eRegion", &GemmSPNode::eRegion_) + .def_ro("a", &GemmSPNode::a_) + .def_ro("b", &GemmSPNode::b_) + .def_ro("c", &GemmSPNode::c_) + .def_ro("e", &GemmSPNode::e_) + .def_ro("transA", &GemmSPNode::transA_) + .def_ro("transB", 
&GemmSPNode::transB_) + .def_ro("m", &GemmSPNode::m_) + .def_ro("n", &GemmSPNode::n_) + .def_ro("k", &GemmSPNode::k_) + .def_ro("clearAccum", &GemmSPNode::clearAccum_) + .def_ro("kPack", &GemmSPNode::kPack_) + .def_ro("wgWait", &GemmSPNode::wgWait_); + } + +private: + mutable bool completed_ = false; +}; + +class GemmSP : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); + TVM_DLL GemmSP(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_SP_H_ diff --git a/tilelang/original/src/op/gemm_sp_py.cc b/tilelang/original/src/op/gemm_sp_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ad8ca9b5075eebad4cbd6db4d9a3209a207344c --- /dev/null +++ b/tilelang/original/src/op/gemm_sp_py.cc @@ -0,0 +1,289 @@ +/*! + * \file tl/op/gemm_sp_py.cc + * \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP) + * operators + */ + +#include "gemm_sp_py.h" +#include "utils.h" + +#include "builtin.h" +#include +#include +#include +#include + +#include "../target/utils.h" +#include "tvm/ffi/string.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a Gemm operator from serialized TL arguments and a buffer + * map. + * + * This constructor deserializes operator parameters from `args` and resolves + * buffer references via `vmap`, populating an internal GemmSPPyNode with: + * - device pointers for A, E, B, C and their corresponding Buffer objects, + * - transpose flags for A and B, + * - matrix dimensions M, N, K, + * - warp allocation policy and clear_accum flag, + * - strides and memory offsets for A and B, + * - optional kPack (must be 1 or 2) and optional wg_wait. + * + * The populated GemmSPPyNode is stored into the wrapper's internal `data_`. + * + * @param args Positional serialized arguments produced by the TL frontend: + * expected layout is: + * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), + * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), + * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), + * (optional) kPack (Int), (optional) wg_wait (Int)] + * @param vmap Mapping from access pointer vars to Buffer objects used to + * resolve the Buffer corresponding to each pointer argument. + * + * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * fails with an ICHECK (runtime assertion). No other validation is + * performed here. 
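+ *
+ * (The current argument tuple also carries trans_E at args[6], shifting the
+ * remaining positional arguments by one, and the buffers are recovered from
+ * the region arguments via NormalizeToBufferRegion rather than a vmap.)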
+ */ +GemmSPPy::GemmSPPy(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->eRegion_ = NormalizeToBufferRegion(args[1]); + node->bRegion_ = NormalizeToBufferRegion(args[2]); + node->cRegion_ = NormalizeToBufferRegion(args[3]); + + node->A = node->aRegion_->buffer; + node->E = node->eRegion_->buffer; + node->B = node->bRegion_->buffer; + node->C = node->cRegion_->buffer; + + node->trans_A = args[4].as().value(); + node->trans_B = args[5].as().value(); + node->trans_E = args[6].as().value(); + node->M = args[7].as().value()->value; + node->N = args[8].as().value()->value; + node->K = args[9].as().value()->value; + node->policy = GemmWarpPolicy(args[10].as().value()->value); + node->clear_accum = args[11].as().value(); + node->stride_A = args[12].as().value()->value; + node->stride_B = args[13].as().value()->value; + node->offset_A = args[14].as().value()->value; + node->offset_B = args[15].as().value()->value; + if (args.size() > 16) { + node->kPack = args[16].as().value()->value; + if (node->kPack != 1 && node->kPack != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 17) { + node->wg_wait = args[17].as().value()->value; + } + data_ = std::move(node); +} + +/** + * @brief Create a copy of this GemmSPPyNode as a TileOperator. + * + * Constructs a new GemmSPPyNode by copying the current node state and returns + * it wrapped in a GemmSPPy TileOperator. + * + * @return TileOperator A GemmSPPy operator that owns a copy of this node. + */ +TileOperator GemmSPPyNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return GemmSPPy(op); +} + +GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const { + int warp_size = TargetGetWarpSize(target); + int num_warps = block_size / warp_size; + bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && + (num_warps % 4 == 0) && CheckWGMMA(); + if (allow_wgmma) { + return GemmInst::kWGMMA; + } else if (TargetIsCDNA(target)) { + return GemmInst::kMFMA; + } else if (TargetIsCuda(target)) { + return GemmInst::kMMA; + } else { + ICHECK(0) << "Unsupported target for gemm: " << target->str(); + } +} + +/** + * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. + * + * Evaluates device-memory placement, data-type combinations, transpose flags, + * and K divisibility constraints required for the Hopper WGMMA code path. + * + * The check returns true only when: + * - B resides in shared memory ("shared" or "shared.dyn"); and + * - (C, A, B) dtypes match one of the supported combinations below and K + * satisfies the required alignment; and + * - for combinations that require specific orientations, A is not transposed + * and B is transposed. + * + * Supported combinations and constraints: + * - C=float16: + * - A=float16, B=float16: K % 16 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 + * - C=float32: + * - A=float16, B=float16: K % 16 == 0 + * - A=bfloat16, B=bfloat16: K % 16 == 0 + * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 + * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 + * - C=int32: + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 + * + * @return true if WGMMA is supported for the current buffers, dtypes, and + * transpose/shape constraints; false otherwise. 
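+ *
+ * (The sparse WGMMA path is not implemented yet: the body below always
+ * returns false, and the dense-style dtype checks are kept only as commented
+ * reference.)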
+ */ +bool GemmSPPyNode::CheckWGMMA() const { + return false; // not supported yet + // if (B.scope() != "shared.dyn" && B.scope() != "shared") { + // return false; + // } + + // if (C->dtype == DataType::Float(16)) { + // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + // return K % 16 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else if (C->dtype == DataType::Float(32)) { + // if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) + // return K % 16 == 0; + // else if (A->dtype == DataType::BFloat(16) && + // B->dtype == DataType::BFloat(16)) + // return K % 16 == 0; + // else if (A->dtype == DataType::Float(32) && B->dtype == + // DataType::Float(32)) + // return (!trans_A) && trans_B && K % 8 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else if (C->dtype == DataType::Int(32)) { + // if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8)) + // return (!trans_A) && trans_B && K % 32 == 0; + // else + // return false; + // } else { + // return false; + // } +} + +/** + * @brief Parse and return the numeric GPU architecture from a Target's "arch" + * attribute. + * + * Examines the target's "arch" string and, if it matches the pattern + * "sm_", returns as an int. If the attribute is present but does not + * match that pattern, returns 0. + * + * Preconditions: the target must have an "arch" attribute (this is checked via + * ICHECK). + * + * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if + * the arch string does not match "sm_". 
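+ *
+ * (This helper duplicates GetArchInt in gemm_py.cc; see the TODO in
+ * gemm_sp_py.h about removing the redundancy.)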
+ */ +static int GetArchInt(Target target) { + int arch_int = 0; + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); + std::string arch = s.value(); + if (arch.rfind("sm_", 0) == 0) { + arch_int = std::stoi(arch.substr(3)); + } else { + arch_int = 0; + } + return arch_int; +} + +Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = GetGemmInst(block_size, T.target); + + auto [warp_m, warp_n] = + policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst); + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) { + auto prim_func = + Downcast((*f)(tvm::ffi::GetRef(this), T.target, + T.thread_bounds, T.thread_var)); + ICHECK(prim_func->attrs.defined()); + auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.has_value()); + if (prim_func->body.as()) { + BlockRealize block_realize = Downcast(prim_func->body); + auto block = block_realize->block; + { + BlockNode *n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, + block); + } + // warp with block realize node + return BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body)); + } else { + LOG(FATAL) << "No lower function found for gemm_sp_py"; + } +} + +LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (completed_) + return {}; + LayoutMap results; + + if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) { + results = Downcast( + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); + } else { + LOG(FATAL) << "No infer layout function found for gemm_sp_py"; + } + + completed_ = true; + return results; +} + +TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); } +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/gemm_sp_py.h b/tilelang/original/src/op/gemm_sp_py.h new file mode 100644 index 0000000000000000000000000000000000000000..2f79c5e156abc933fabb070a6cbbaa8ba1ac883e --- /dev/null +++ b/tilelang/original/src/op/gemm_sp_py.h @@ -0,0 +1,94 @@ +/*! + * \file tl/op/gemm_sp_py.h + * \brief Define gemm_sp_py operator. 
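+ *
+ * Lowering and layout inference for this operator are delegated to Python via
+ * the registered "tl.gemm_sp_py.lower" and "tl.gemm_sp_py.infer_layout"
+ * global functions.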
+ * + */ + +// TODO: @botbw: remove redundant code with gemm_py.h + +#ifndef TVM_TL_OP_GEMM_SP_PY_H_ +#define TVM_TL_OP_GEMM_SP_PY_H_ + +#include "gemm_sp.h" +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +class GemmSPPyNode : public TileOperatorNode { +public: + bool CheckWGMMA() const; + tir::Buffer A, E, B, C; + // pointer to the A, E, B, C + BufferRegion aRegion_, eRegion_, bRegion_, cRegion_; + bool trans_A, trans_B, trans_E; + int M, N, K; + int stride_A, stride_B; + int offset_A, offset_B; + PrimExpr clear_accum = const_false(); + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack = 1; + int wg_wait = 0; + + // use GemmWarp Policy here as the atom size are flexible in v2 + mutable GemmWarpPolicy policy; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("A", &GemmSPPyNode::A) + .def_ro("E", &GemmSPPyNode::E) + .def_ro("B", &GemmSPPyNode::B) + .def_ro("C", &GemmSPPyNode::C) + .def_ro("aRegion", &GemmSPPyNode::aRegion_) + .def_ro("eRegion", &GemmSPPyNode::eRegion_) + .def_ro("bRegion", &GemmSPPyNode::bRegion_) + .def_ro("cRegion", &GemmSPPyNode::cRegion_) + .def_ro("trans_A", &GemmSPPyNode::trans_A) + .def_ro("trans_B", &GemmSPPyNode::trans_B) + .def_ro("trans_E", &GemmSPPyNode::trans_E) + .def_ro("M", &GemmSPPyNode::M) + .def_ro("N", &GemmSPPyNode::N) + .def_ro("K", &GemmSPPyNode::K) + .def_ro("stride_A", &GemmSPPyNode::stride_A) + .def_ro("stride_B", &GemmSPPyNode::stride_B) + .def_ro("offset_A", &GemmSPPyNode::offset_A) + .def_ro("offset_B", &GemmSPPyNode::offset_B) + .def_ro("clear_accum", &GemmSPPyNode::clear_accum) + .def_ro("kPack", &GemmSPPyNode::kPack) + .def_ro("wg_wait", &GemmSPPyNode::wg_wait) + .def_ro("policy", &GemmSPPyNode::policy); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + +private: + // Target GEMM instruction + GemmInst GetGemmInst(int block_size, Target target) const; + + mutable bool completed_ = false; +}; + +class GemmSPPy : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator, + GemmSPPyNode); + TVM_DLL GemmSPPy(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_SP_PY_H_ \ No newline at end of file diff --git a/tilelang/original/src/op/logical.cc b/tilelang/original/src/op/logical.cc new file mode 100644 index 0000000000000000000000000000000000000000..0de6658bdad5597bf683901dd14447467c8a589c --- /dev/null +++ b/tilelang/original/src/op/logical.cc @@ -0,0 +1,55 @@ +/*! + * \file tl/op/logical.cc + * \brief Logical operations. 
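+ *
+ * Implements the tl.any_of / tl.all_of intrinsics, which lower to extern
+ * calls of the tl::Any and tl::All device helpers.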
+ * + */ + +#include +#include +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tl { +using namespace tir; + +PrimExpr any_of_op(PrimExpr args) { + const CallNode *call = args.as(); + CHECK(call != nullptr); + const Array &arg = call->args; + ICHECK_EQ(arg.size(), 2); + PrimExpr buffer_address = arg[0]; + PrimExpr elems = arg[1]; + return tir::Call(DataType::Bool(), tir::builtin::call_extern(), + {StringImm("tl::Any"), buffer_address, elems}); +} + +PrimExpr all_of_op(PrimExpr args) { + const CallNode *call = args.as(); + CHECK(call != nullptr); + const Array &arg = call->args; + ICHECK_EQ(arg.size(), 2); + PrimExpr buffer_address = arg[0]; + PrimExpr elems = arg[1]; + return tir::Call(DataType::Bool(), tir::builtin::call_extern(), + {StringImm("tl::All"), buffer_address, elems}); +} + +TVM_REGISTER_OP("tl.any_of") + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)) + .set_attr("TScriptPrinterName", "any_of") + .set_attr("cuda.FLowerIntrinsic", any_of_op); + +TVM_REGISTER_OP("tl.all_of") + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)) + .set_attr("TScriptPrinterName", "all_of") + .set_attr("cuda.FLowerIntrinsic", all_of_op); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/math.cc b/tilelang/original/src/op/math.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9de966ea8255caf5925170efcc9bc59a69cfd89 --- /dev/null +++ b/tilelang/original/src/op/math.cc @@ -0,0 +1,67 @@ +/*! + * \file tl/op/math.cc + * \brief Math operations. + * + */ + +#include +#include +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tl { +using namespace tir; + +PrimExpr pow_of_int_op(PrimExpr args) { + const CallNode *call = args.as(); + CHECK(call != nullptr); + const Array &arg = call->args; + ICHECK_EQ(arg.size(), 2); + PrimExpr base = arg[0]; + PrimExpr exp = arg[1]; + String pow_of_int_name = + "tl::pow_of_int<" + std::to_string(exp.as()->value) + ">"; + return tir::Call(base.dtype(), tir::builtin::call_extern(), + {StringImm(pow_of_int_name), base}); +} + +TVM_REGISTER_OP("tl.pow_of_int") + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)) + .set_attr("TScriptPrinterName", "pow_of_int") + .set_attr("hip.FLowerIntrinsic", pow_of_int_op) + .set_attr("cuda.FLowerIntrinsic", pow_of_int_op); + +PrimExpr infinity_op(PrimExpr args) { + const CallNode *call = args.as(); + CHECK(call != nullptr); + const DataType &dtype = call->dtype; + ICHECK_EQ(dtype.lanes(), 1); + + // NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype + if (dtype.is_float()) { + if (dtype.bits() == 64 || dtype.bits() == 32 || dtype.bits() == 16) { + return FloatImm(dtype, std::numeric_limits::infinity(), + call->span); + } + } else if (dtype.is_bfloat16()) { + return FloatImm(dtype, std::numeric_limits::infinity(), call->span); + } + LOG(FATAL) << "Cannot decide infinity for type " << dtype; + throw; // Unreachable, keeps compiler happy +} + +TVM_REGISTER_OP("tl.infinity") + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)) + .set_attr("TScriptPrinterName", "infinity") + .set_attr("cuda.FLowerIntrinsic", infinity_op) + .set_attr("hip.FLowerIntrinsic", infinity_op); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/operator.cc b/tilelang/original/src/op/operator.cc new file mode 100644 index 
0000000000000000000000000000000000000000..302ee3e37443d4ccf7655ecb26577848e8afd3e4 --- /dev/null +++ b/tilelang/original/src/op/operator.cc @@ -0,0 +1,81 @@ +/*! + * \file tl/op/op.cc + * + * Define operators usd in tile library. + */ + +#include "operator.h" + +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +/** + * @brief Construct a TileOperator from a TIR Call using a registered builder. + * + * Looks up a builder function in the "TLOpBuilder" Op attribute map for the + * operator referenced by `call` and invokes it to produce a TileOperator. If no + * builder is registered for the operator, returns a default-constructed (empty) + * TileOperator. + * + * @param call The TIR Call whose operator and arguments will be used to build + * the TileOperator. + * @return TileOperator The constructed TileOperator, or a default (empty) + * TileOperator if no builder exists. + */ +TileOperator ParseOperator(Call call) { + auto op_map = Op::GetAttrMap("TLOpBuilder"); + Op op = call->op.as().value(); + if (op_map.count(op)) { + auto tile_op = op_map[op](call->args); + ICHECK(tile_op.defined()); + return tile_op; + } + return TileOperator(); +} + +/** + * @brief Parse a TileOperator from a TIR statement if it contains a call. + * + * If `stmt` is an Evaluate node whose value is a Call, delegates to + * ParseOperator(Call, BufferMap) and returns the resulting TileOperator. + * Otherwise returns a default-constructed (empty) TileOperator. + * + * @param stmt TIR statement to inspect; expected to be an Evaluate of a Call. + * @return TileOperator Parsed operator on success, or a default (empty) + * TileOperator if `stmt` is not an Evaluate(Call). + */ +TileOperator ParseOperator(Stmt stmt) { + if (stmt.as() && stmt.as()->value.as()) { + auto call = stmt.as()->value.as(); + return ParseOperator(tvm::ffi::GetRef(call)); + } + return TileOperator(); +} + +/** + * @brief Extracts the Var referenced by a `tvm_access_ptr` call expression. + * + * The function expects `expr` to be a `Call` to the builtin `tvm_access_ptr` + * and returns the `Var` found in the call's second argument (`args[1]`). The + * function performs runtime checks and will abort if `expr` is not a call, the + * call is not `tvm_access_ptr`, or the second argument is not a `Var`. + * + * @param expr A `PrimExpr` representing a `tvm_access_ptr(...)` call. + * @return tvm::Var The `Var` referenced by the `tvm_access_ptr` call. + */ +Var GetVarFromAccessPtr(const PrimExpr &expr) { + auto call = expr.as(); + ICHECK(call); + ICHECK(call->op.same_as(builtin::tvm_access_ptr())); + auto var = call->args[1].as(); + ICHECK(var); + return tvm::ffi::GetRef(var); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/operator.h b/tilelang/original/src/op/operator.h new file mode 100644 index 0000000000000000000000000000000000000000..c246864e40291cecd0df71f43233474ec4d7ceca --- /dev/null +++ b/tilelang/original/src/op/operator.h @@ -0,0 +1,99 @@ +/*! + * \file tl/op/op.h + * \brief Tile library operations. 
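+ *
+ * Declares the TileOperatorNode interface (Lower / InferLayout / Clone), the
+ * LowerArgs / LayoutInferArgs argument structs, and the
+ * TIR_REGISTER_TL_TILE_OP registration macro used by concrete tile operators.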
+ * + */ + +#ifndef TVM_TL_OP_OP_H_ +#define TVM_TL_OP_OP_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "../layout/layout.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +using AddWorkspaceCallback = std::function; +using LayoutMap = Map; +using BufferMap = Map; + +enum class InferLevel : uint8_t { + kFree = 0, + kCommon = 1, + kStrict = 2, +}; + +struct LowerArgs { + Target target; + Range thread_bounds; + Var thread_var; + AddWorkspaceCallback AddWorkspace; + LayoutMap layout_map; + Map buffer_remap; + // Map from LetStmt variable to its bound expression, for resolving + // fragment buffer accesses through let bindings + Map let_var_to_expr; +}; + +struct LayoutInferArgs { + Target target; + Range thread_bounds; + LayoutMap layout_map; + arith::Analyzer *analyzer; + bool buffer_oob = false; + Map buffer_remap; + // Map from LetStmt variable to its bound expression, for resolving + // fragment buffer accesses through let bindings + Map let_var_to_expr; +}; + +class TileOperator; + +class TileOperatorNode : public Object { +public: + virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0; + + virtual LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const = 0; + + virtual TileOperator Clone() const = 0; + + TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object); +}; + +class TileOperator : public ObjectRef { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef, + TileOperatorNode); +}; + +Var GetVarFromAccessPtr(const PrimExpr &expr); + +TileOperator ParseOperator(Call call); +TileOperator ParseOperator(Stmt stmt); + +using OpBuilderFunc = ffi::TypedFunction)>; + +#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \ + const Op &Entry::Get() { \ + static const Op &op = Op::Get("tl.tileop." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tl.tileop." #OpName) \ + .set_attr("TScriptPrinterName", #OpName) \ + .set_attr( \ + "TLOpBuilder", [](Array args) { return Entry(args); }) + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_OP_H_ diff --git a/tilelang/original/src/op/parallel.cc b/tilelang/original/src/op/parallel.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbc6ea8e24539a21756371ac456fabf44f2dcb7b --- /dev/null +++ b/tilelang/original/src/op/parallel.cc @@ -0,0 +1,718 @@ +/*! + * \file op/parallel.cc + * \brief Define Parallel for operator + */ + +#include "parallel.h" + +#include +#include + +#include "../layout/utils.h" +#include "../target/utils.h" +#include "../transform/loop_partition.h" +#include "../transform/loop_vectorize.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +namespace attr { +/*! \brief Mark that how the loop is vectorized. */ +constexpr const char *coalesced_width = "coalesced_width"; +} // namespace attr + +// ProveFragmentContains checks whether the threads that access elements of a +// smaller fragment (small_frag) are a subset of the threads that access +// elements of a larger fragment (large_frag) for any given loop index. This +// function ensures that if the small fragment's layout corresponds to the loop +// itself, accessing the large fragment's elements is valid. Additionally, if +// small is updated to large, the originally valid access remains valid. The +// proof is performed by: +// +// 1. Defining a variable `rep_small` to represent the replicate index of the +// small fragment that is being checked. +// 2. 
Using the `small_frag_indices` and `rep_small` to derive the thread +// accessing +// the element in the small fragment. +// 3. Using `large_frag_indices` to derive the physical index of the large +// fragment +// along with the thread information, and then feeding these into the inverse +// of the large fragment to obtain the logical index and replicate index. +// 4. Verifying the mapping by checking whether the computed thread using the +// inverse +// layout corresponds to the original thread calculated for the small +// fragment. If they don't match, this indicates that the inverse layout's +// domain does not include the thread and thus the access is invalid. +bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, + Array small_frag_indices, + Array large_frag_indices, + arith::Analyzer &analyzer_) { + Var rep_small("__checking_frag_contains_rep"); + analyzer_.Bind(rep_small, + Range(IntImm(small_frag->ReplicateExtent()->dtype, 0), + small_frag->ReplicateExtent()), + true); // Bind the replicate extent of small_frag. + // Derive thread for small_frag. + auto thread = small_frag->ForwardThread(small_frag_indices, rep_small); + + // Get physical index and thread for large_frag. + auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices); + // Add small_frag's thread to the large fragment's thread info. + large_frag_physical_and_thread.push_back(thread); + // Get the inverse of the large fragment. + auto inv_large_frag = large_frag->Inverse(); + // Compute logical index and replicate index using inverse layout. + auto inv_large_frag_logical_and_rep = + inv_large_frag->Forward(large_frag_physical_and_thread); + + // Extract replicate index from the result. + auto inv_large_frag_rep = + inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1]; + + // Calculate thread based on the logical index and replicate index. + auto check_thread = + large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep); + + // Simplify the difference between the threads. + auto diff = analyzer_.Simplify(thread - check_thread); + // If the difference is zero, the threads match and the access is valid. + return is_zero(diff); +} + +class IfBufferRemapLoopGenerator : public StmtExprMutator { +public: + static For run(Stmt stmt, Map buffer_remap, + Map layout_map) { + IfBufferRemapLoopGenerator generator(buffer_remap, layout_map); + return Downcast(generator(std::move(stmt))); + } + +private: + IfBufferRemapLoopGenerator(Map buffer_remap, + Map layout_map) + : buffer_remap_(buffer_remap), layout_map_(layout_map) {} + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (buffer_remap_.count(load->buffer)) { + auto new_indices = layout_map_[load->buffer]->Forward(load->indices); + auto new_buffer = buffer_remap_[load->buffer]; + + return BufferLoad(new_buffer, new_indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + if (buffer_remap_.count(store->buffer)) { + auto new_indices = layout_map_[store->buffer]->Forward(store->indices); + auto new_buffer = buffer_remap_[store->buffer]; + return BufferStore(new_buffer, store->value, new_indices); + } + return store; + } + + Map buffer_remap_; + Map layout_map_; +}; + +/** + * @brief Handle a parallel For node during traversal, collecting loop metadata. 
+ * + * Visits a parallel loop, asserts the loop is parallel, records a data-parallel + * IterVar for the loop, binds the loop variable range into the analyzer scope, + * and extracts any reducer information from the loop's annotations into the + * visitor's reducer_info_map_. Continues traversal into the loop body. + */ +void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) { + if (op->kind == ForKind::kParallel) + p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var, + IterVarType::kDataPar)); + else + p->inner_vars_.Set(op->loop_var, + IterVar(Range(op->min, op->extent), op->loop_var, + IterVarType::kOrdered)); + p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + auto reducer_info_map = + op->annotations.Get(attr::kReducerInfo)->as>(); + if (reducer_info_map) { + for (auto &&[buffer, info] : reducer_info_map.value()) + p->reducer_info_map_.Set(buffer, info); + } + StmtExprVisitor::VisitStmt_(op); +} + +void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) { + if (op->buffer.scope() == "local.fragment") { + if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { + ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) + << op->buffer << ": " << op->indices << " and " + << p->indice_map_.at(op->buffer); + } else { + p->indice_map_.Set(op->buffer, op->indices); + } + p->buffer_is_write_.insert(op->buffer); + } + StmtExprVisitor::VisitStmt_(op); +} + +void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { + if (op->buffer.scope() == "local.fragment") { + if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { + ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) + << op->buffer << ": " << op->indices << " and " + << p->indice_map_.at(op->buffer); + } else { + p->indice_map_.Set(op->buffer, op->indices); + } + } + StmtExprVisitor::VisitExpr_(op); +} + +ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) { + V.VisitStmt(root); +} + +TileOperator ParallelOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return ParallelOp(op); +} + +void ParallelOpNode::ExpandLetBindings( + const Map &let_var_to_expr) { + if (let_var_to_expr.empty()) + return; + + // Helper function to recursively find BufferLoads through let bindings + std::function expand = [&](const PrimExpr &expr) { + PostOrderVisit(expr, [&](const ObjectRef &node) { + if (auto bl = node.as()) { + if (bl->buffer.scope() == "local.fragment" && + !indice_map_.count(bl->buffer)) { + indice_map_.Set(bl->buffer, bl->indices); + } + } else if (auto var_node = node.as()) { + auto var = tvm::ffi::GetRef(var_node); + if (let_var_to_expr.count(var)) { + expand(let_var_to_expr[var]); + } + } + }); + }; + + // Scan all let bindings + for (const auto &[var, expr] : let_var_to_expr) { + expand(expr); + } +} + +Stmt ParallelOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + return root_; +} + +bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const { + auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; }); + return StructuralEqual()(indice_map_[buffer], common_indice); +} + +/*! \brief Infer the layout for parallel operations based on different inference + * levels + * + * The inference level controls how aggressively we try to infer and optimize + * layouts: + * - kStrict (2): Most conservative level. Only allows explicitly defined + * layouts. Returns empty layout map if loop_layout_ is not already defined. 
+ * Used when exact layout control is required. + * + * - kCommon (1): Intermediate level between strict and free. + * Allows common layout patterns while maintaining some + * constraints. + * + * - kFree (0): Most permissive level. Allows maximum optimization freedom. + * Will attempt layout inference even without source buffers. + * Can generate new layouts based on vectorization and thread + * bounds. Used when maximum performance optimization is desired. + */ +LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (loop_layout_.defined()) + return {}; + + // Expand let bindings to find fragment buffer accesses + if (!T.let_var_to_expr.empty()) { + const_cast(this)->ExpandLetBindings(T.let_var_to_expr); + } + + if (level == InferLevel::kStrict) { + LayoutMap results; + // Deduce buffers that should be complicated replicated. + // For example: + // for i in T.Parallel(m): + // fragment[0] = x[i] + // then fragment[0] must be replicated on all threads. + for (const auto &[buffer, indices] : indice_map_) { + if (T.layout_map.count(buffer)) { + continue; + } + if (buffer.scope() != "local.fragment") + continue; + + // Check if all indices are zero + bool all_indices_zero = true; + for (const auto &index : indices) { + if (const auto *imm = index.as()) { + if (imm->value != 0) { + all_indices_zero = false; + LOG(FATAL) + << "Fragment buffer access with non-zero index [" << imm->value + << "] is not supported. " + << "Only fragment[0] access is allowed within T.Parallel loop."; + } + } else { + // Non-constant index, not all zero + all_indices_zero = false; + } + } + + // Only set layout if all indices are zero + if (all_indices_zero) { + Array forward_vars; + for (const auto &s : buffer->shape) { + forward_vars.push_back( + IterVar(Range(0, s), Var(), IterVarType::kDataPar)); + } + Var rep; + auto rep_iter = + IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar); + + // Use default fragment indexing (single output dim) to + // stay consistent with other ops (e.g., ReduceOp), and + // bind the thread range for comparability. + const PrimExpr &forward_thread = rep; + auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread, + rep_iter) + ->BindThreadRange(T.thread_bounds); + results.Set(buffer, frag); + } + } + return results; + } + auto buffer_is_completed_replicated = [&](const Buffer &buffer) { + if (buffer.scope() != "local.fragment") + return false; + auto frag = T.layout_map[buffer].as().value(); + // buffer indices should be IntImm + for (const auto &index : indice_map_[buffer]) { + if (!index.as()) { + return false; + } else if (index.as()->value != 0) { + LOG(FATAL) << "buffer " << buffer << " is not completed replicated"; + } + } + return frag->IsCompletedReplicated(); + }; + // Collect fragment buffers with const index and all fragment_buffers + std::vector const_index_fragment_buffer, fragment_buffers; + for (const auto &[buffer, indices] : indice_map_) { + if (buffer.scope() != "local.fragment") + continue; + fragment_buffers.push_back(buffer); + + bool is_const_index = true; + for (const auto &index : indices) { + if (!index.as()) { + is_const_index = false; + break; + } + } + if (is_const_index) { + const_index_fragment_buffer.push_back(buffer); + } + } + + // Determine if common layout propagation should be applied. + // If there are fragment buffers with non-constant indices, we need to + // propagate the common layout pattern to ensure consistency across all + // fragments. 
Example cases: + // - Need propagation: frag_a[0] = T.min(frag_a[0], frag_b[i]) + // (const index frag_a interacts with non-const index frag_b) + // - No propagation needed: shared_a[i] = frag_a[0] + // (const index frag_a with non-fragment buffer) + + bool allow_layout_propgate = + const_index_fragment_buffer.empty() || + (fragment_buffers.size() > const_index_fragment_buffer.size()); + + // Step 1: try to infer loop's partition from a source fragment + Buffer source_buffer, read_source_buffer; + Buffer replicated_write_buffer; // Backup: fully replicated write buffer + + for (const auto &[buffer, indices] : indice_map_) { + if (T.layout_map.count(buffer)) { + // skip reducers with rep=ALL + if (auto info = reducer_info_map_.Get(buffer->data); + info && info.value()->rep == ReducerRepType::ALL) + continue; + + auto frag = T.layout_map[buffer].as().value(); + bool is_fully_replicated = buffer_is_completed_replicated(buffer); + + if (buffer_is_write_.count(buffer)) { + source_buffer = buffer; + } else { + // Keep the buffer with largest number of indices + // (which means the inference based on that buffer is more accurate) + // as read_source_buffer to get more accurate layout + // if the buffer is completed replicated, we don't need to infer the + // layout from this buffer. + if ((!read_source_buffer.defined() || + indice_map_[buffer].size() > + indice_map_[read_source_buffer].size())) { + read_source_buffer = buffer; + } + // If the buffer is not replicated and shape is equal to the + // source_buffer, use it as source_buffer because the layout inference + // is more accurate + if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) { + source_buffer = buffer; + } + } + } + } + auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) { + Fragment src_layout = T.layout_map[buffer].as().value(); + DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `" + << buffer << "` of layout " << src_layout->DebugOutput() << '\n'; + + Fragment result; + if (IsCommonAccessIndice(buffer)) { + result = src_layout; + } else { + Var rep; + auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, + IterVarType::kDataPar); + PrimExpr loop_var_to_thread = + src_layout->ForwardThread(indice_map_[buffer], rep); + loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread); + PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) { + if (auto opt_var = objref.as(); + opt_var && inner_vars_.count(*opt_var)) { + std::ostringstream oss; + oss << "loop_var_to_thread = " << loop_var_to_thread + << "contains inner var" << *opt_var; + throw LayoutConflictException(oss.str()); + } + }); + + try { + result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter) + ->BindThreadRange(T.thread_bounds); + } catch (const tvm::runtime::Error &err) { + std::ostringstream msg; + msg << "Layout inference for buffer `" << buffer->name + << "` failed inside `T.parallel` loop."; + + msg << "\nUnderlying TVM error: " << err.what(); + msg << "\nProblematic loop AST:\n " << root_; + msg << "\nHint: ensure the loop extent divides the thread binding or " + "adjust the fragment mapping."; + LOG(FATAL) << msg.str(); + } + } + DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get " + << result->DebugOutput() << '\n'; + return result; + }; + + // Try to infer loop layout from buffers in order of preference: + // 1. Non-replicated write buffer (most reliable) + // 2. Non-replicated read buffer + // 3. Fully replicated write buffer (backup, may cause issues) + // 4. 
Free inference mode (no source buffer) + + if (source_buffer.defined() && allow_layout_propgate) { + loop_layout_ = compute_loop_layout_from_buffer(source_buffer); + } else if (level == InferLevel::kFree) { + // For free layout inference + // If replication exists and buffer has cross-thread shared memory access, + // add predicate + bool has_cross_thread_access = false; + PostOrderVisit(root_, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + // check if scope is shared or global + if (store->buffer.scope() == "shared" || + store->buffer.scope() == "shared.dyn" || + store->buffer.scope() == "global") { + has_cross_thread_access = true; + } + } else if (const auto *load = obj.as()) { + // check if scope is shared or global + if (load->buffer.scope() == "shared" || + load->buffer.scope() == "shared.dyn" || + load->buffer.scope() == "global") { + has_cross_thread_access = true; + } + } + }); + + // check if loop body contains a "pure" buffer store (i.e., direct + // assignment, not compound update) + std::vector store_shared_global_buffers, store_fragment_buffers; + // Buffers that scope is above fragments. + // global, shared, shared.dyn + // which can be used to analysis replicate case + PostOrderVisit(root_, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + auto buffer = store->buffer; + if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" || + buffer.scope() == "global") { + store_shared_global_buffers.emplace_back(buffer); + } else if (buffer.scope() == "local.fragment") { + store_fragment_buffers.emplace_back(buffer); + } + } + }); + if (read_source_buffer.defined() && allow_layout_propgate) { + loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer); + } + + if (!loop_layout_.defined()) { + // No source buffer available, use free mode inference + // Vectorize Size must be aware of the buffer_remap + // As the pass will do post processing to the layout + auto maybe_remapped_root_ = + IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); + int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); + DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; + + PrimExpr loop_total_size = 1; + for (Stmt l = root_; l.as().has_value(); + l = l.as().value()->body) + loop_total_size = loop_total_size * l.as().value()->extent; + DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size + << '\n'; + while (!analyzer_.CanProve( + floormod(loop_total_size, + T.thread_bounds->extent * vector_size) == 0) && + vector_size > 1) + vector_size /= 2; + DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = " + << vector_size << '\n'; + + // Check if coalesced_width is defined + if (auto coalesced_width = + root_->annotations.Get(tl::attr::coalesced_width)) { + if (const auto *imm = coalesced_width->as()) { + int expected = imm->value; + // Verify that vector_size is divisible by expected + if (vector_size % expected != 0) { + LOG(FATAL) << "Vector size " << vector_size + << " is not divisible by coalesced width " << expected; + } + vector_size = expected; + } else { + LOG(FATAL) << "coalesced_width should be an IntImmNode."; + } + } + DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ + << " ############# vector_size = " << vector_size + << ", thread_bounds = " << T.thread_bounds << '\n'; + loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds); + DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = " + << loop_layout_->DebugOutput() << '\n'; + } + + // Lambda 
that guards replicated accesses: + // - When a loop layout replicates a fragment buffer (rep > 1), each thread + // observes the same fragment elements. Blindly storing to shared/global + // memory in that case would add the same value multiple times. + // - We therefore restrict the store so that only the replica with rep == 0 + // performs the update (e.g. global[i] += fragment[i] only fires once). + // Trigger conditions for this guard: + // 1) There are cross-thread stores targeting shared/global memory (no + // fragment stores in this branch; atomic_add and similar remain TODO). + // 2) The loop layout replicate extent is greater than 1, inferred from the + // thread bounds captured in the layout. + + [this, &store_shared_global_buffers, &store_fragment_buffers, + &has_cross_thread_access, &const_index_fragment_buffer, &T]() { + if (is_one(loop_layout_->ReplicateExtent())) + return; + if (!has_cross_thread_access) + return; + + if (!store_fragment_buffers.empty()) { + // Iterate replicated fragment stores: when the fragment index is a + // constant (e.g. fragment[0]), every thread touches the same slot, so + // the rep == 0 predicate is unnecessary. Example: for i in + // T.Parallel(...): + // shared[i] = ... + // fragment[0] = ... + bool replicate_is_from_dynamic_index_fragment = false; + for (const auto &fragment : store_fragment_buffers) { + if (!T.layout_map.count(fragment)) { + continue; + } + + auto fragment_layout = T.layout_map[fragment].as().value(); + if (is_one(fragment_layout->ReplicateExtent())) + continue; + + if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(), + loop_layout_->ReplicateExtent())) + continue; + if (std::find(const_index_fragment_buffer.begin(), + const_index_fragment_buffer.end(), + fragment) == const_index_fragment_buffer.end()) { + replicate_is_from_dynamic_index_fragment = true; + } + } + + if (!replicate_is_from_dynamic_index_fragment) + return; + + ICHECK(store_shared_global_buffers.empty()) + << "Invalid layout: cannot have both fragment and shared store " + "buffers " + "in replicated loop layout."; + return; + } else { + // Now, store is global or shared + // or T.call_extern or T.call_intrin ... 
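+      // Illustrative sketch (the replica count and thread ids below are
+      // assumptions for exposition, not taken from this kernel): if the loop
+      // layout replicates each element across rep = 4 threads, those four
+      // threads all hold the same fragment value, so an unguarded store to
+      // shared/global memory would be issued four times per element. The code
+      // below inverts the loop layout, reads back the replica coordinate of
+      // the current thread, and predicates the body on rep == 0 so the
+      // cross-thread store fires exactly once.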
+ auto inv = loop_layout_->Inverse(); + Array fwd; + for (size_t i = 0; i < loop_layout_->OutputDim(); i++) + fwd.push_back(0); + fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min); + auto rep = inv->Forward(fwd).back(); + AddPredicate(EQ(rep, 0)); + } + }(); + } else { + return {}; + } + // check loop_layout_ is injective + auto injective_res = loop_layout_->DetectInjective(); + if (!injective_res->errors.empty()) { + std::ostringstream oss; + oss << "Loop layout is not injective: " << loop_layout_->DebugOutput() + << '\n' + << " errors: " << injective_res->errors << '\n' + << " loop AST: " << root_; + throw LoopLayoutInjectiveException(oss.str()); + } + + PrimExpr loop_thread_extent = loop_layout_->ThreadExtent(); + + auto block_size = T.thread_bounds->extent; + if (loop_layout_.defined()) { + if (loop_layout_->ThreadRange().defined()) { + auto thread_range = loop_layout_->ThreadRange(); + block_size = thread_range->extent; + AddPredicate(GE(InputPlaceholder(0), thread_range->min)); + AddPredicate( + LT(InputPlaceholder(0), thread_range->min + thread_range->extent)); + } + } + + if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) { + AddPredicate( + LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min)); + } + + // Step 2: Check that the loop's partition can correctly align with all source + // fragment, and infer layout only when it's not yet layout-ed + LayoutMap results; + for (const auto &[buffer, _] : indice_map_) { + if (T.layout_map.count(buffer)) { + auto fragment = T.layout_map[buffer].as().value(); + auto vars = + loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); + if (!ProveFragmentContains(loop_layout_, fragment, vars, + indice_map_[buffer], analyzer_)) { + std::ostringstream oss; + oss << "Layout infer conflict between " << buffer << " and " + << source_buffer << " in T.Parallel loop:" << '\n' + << " loop " << loop_layout_->DebugOutput() << '\n' + << " fragment " << fragment->DebugOutput() << '\n'; + throw LayoutConflictException(oss.str()); + } + } else { + auto dst_layout = + CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds); + results.Set(buffer, dst_layout); + } + } + return results; +} + +Optional ParallelOpNode::GetPredicate(Var thread_var) const { + if (predicate_.defined()) { + return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}}); + } else { + return std::nullopt; + } +} + +Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { + ICHECK(loop_layout_.defined()); + if (IsCommonAccessIndice(buffer)) { + return loop_layout_; + } + // Prefer a simple path: if original 2D indices form a bijective map, invert + // them directly and avoid introducing a synthetic replicate dimension. 
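+  // Illustrative example (hypothetical access pattern): for a parallel nest
+  // over (i, j) that reads frag[j, i], the index map is bijective, so the
+  // inverse of Layout(loop_vars_, indices) recovers (i, j) directly from the
+  // buffer coordinates and no synthetic replicate iterator is needed. An
+  // access such as frag[i] from the same 2-D nest leaves j unused, fails the
+  // bijective check, and falls through to the DivideUnusedIterators path
+  // further below.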
+ { + auto res2d = + arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1, + arith::IterMapLevel::Bijective, + const_cast(&analyzer_)); + if (res2d->errors.empty()) { + Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse(); + PrimExpr indice_rep_extent = 1; + PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); + PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent; + Array fwd2; + for (size_t i = 0; i < buffer->shape.size(); i++) { + fwd2.push_back(InputPlaceholder(i)); + } + PrimExpr thd_b2 = + loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt); + return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent, + std::nullopt) + ->CondenseReplicateVar(); + } + } + // Otherwise, infer an extra flattened iterator that captures truly-unused + // pieces of the loop space (if any), then try inversion with it. + PrimExpr rep_b = MakeFlattenedExpression( + DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_)); + auto bijective_indice = indice_map_[buffer]; + bijective_indice.push_back(rep_b); + Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse(); + + PrimExpr indice_rep_extent = + ind_inv->InputShape().back(); // this is the size of rep_b + PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); + PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent; + Array fwd; + for (size_t i = 0; i < buffer->shape.size(); i++) { + fwd.push_back(InputPlaceholder(i)); + } + fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent)); + PrimExpr thd_b = loop_layout_->ForwardThread( + ind_inv->Forward(fwd), + FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); + return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, + std::nullopt) + ->CondenseReplicateVar(); +} + +TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); } + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/parallel.h b/tilelang/original/src/op/parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..88dd1debf7697e58af4535e524a2865510cce821 --- /dev/null +++ b/tilelang/original/src/op/parallel.h @@ -0,0 +1,146 @@ +/*! + * \file tl/op/parallel.h + * \brief Infer layout from ops and parallel for + */ + +#ifndef TVM_TL_OP_PARALLEL_H_ +#define TVM_TL_OP_PARALLEL_H_ + +#include +#include + +#include "../layout/layout.h" +#include "../transform/layout_reducer.h" +#include "./operator.h" + +/** + * Conjoin `expr` into the operator's predicate (logical AND). If no predicate + * exists yet, `expr` becomes the predicate. + * + * @param expr Predicate expression to add. + */ +namespace tvm { +namespace tl { + +using namespace tir; + +bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, + Array small_frag_indices, + Array large_frag_indices, + arith::Analyzer &analyzer_); + +class ParallelOpNode; + +class ParallelLoopNestVisitor : public StmtExprVisitor { +private: + ParallelLoopNestVisitor(ParallelOpNode *op) : p(op) {}; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const BufferStoreNode *op) override; + void VisitExpr_(const BufferLoadNode *op) override; + + ParallelOpNode *p; + + friend class ParallelOpNode; +}; + +// ParallelOpNode represents a parallel for loop operator in TileLang. +// It is responsible for inferring layouts, holding loop structure, and managing +// predicates. +class ParallelOpNode : public TileOperatorNode { +public: + // The root For loop node. 
+ For root_; + // The inferred layout for the loop, mutable to allow lazy inference. + mutable Fragment loop_layout_; + // The predicate expression for the loop, if any, mutable for lazy + // construction. + mutable Optional predicate_; + + // Type key for TVM object system. + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("root", &ParallelOpNode::root_) + .def_ro("loop_layout", &ParallelOpNode::loop_layout_) + .def_ro("predicate", &ParallelOpNode::predicate_); + } + + // Construct from a root For loop. + ParallelOpNode(For root); + + // Lower the operator to a TIR statement. + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + + // Infer the layout for this parallel operator. + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + // Copy constructor for ParallelOpNode. + ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) { + loop_layout_ = other.loop_layout_; + predicate_ = other.predicate_; + } + + // Get the inferred loop layout. + Fragment GetLoopLayout() const { return loop_layout_; } + // Get the root For loop. + For GetRoot() const { return root_; } + // Get the mapping from buffer to access indices. + Map> GetIndiceMap() const { return indice_map_; } + // Get the predicate for a given thread variable. + Optional GetPredicate(Var thread_var) const; + + // Clone this operator. + TileOperator Clone() const override; + +private: + // Complete the fragment layout for a given buffer. + Fragment CompleteBufferFragment(const Buffer &buffer) const; + // Check if the buffer is accessed with common indices (i.e., loop variables). + bool IsCommonAccessIndice(const Buffer &buffer) const; + // Add a predicate to the current predicate expression. + void AddPredicate(const PrimExpr &expr) const { + predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; + } + // Expand let bindings to find fragment buffer accesses and add them to + // indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0], + // ...) + void ExpandLetBindings(const Map &let_var_to_expr); + + // Allow ParallelLoopNestVisitor to access private members. + friend class ParallelLoopNestVisitor; + + // Visitor for collecting loop nest information. + ParallelLoopNestVisitor V; + // Mapping from buffer to their access indices in the loop. + Map> indice_map_; + // Set of buffers that are written to in the loop. + std::unordered_set buffer_is_write_; + // The loop variables for the parallel loop nest. + Array loop_vars_; + // The inner_vars_ + Map inner_vars_; + // Analyzer for simplifying and analyzing expressions, mutable for lazy use. + mutable arith::Analyzer analyzer_; + // Mapping from buffer to reducer info. + Map reducer_info_map_; +}; + +class ParallelOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator, + ParallelOpNode); + + ParallelOp(const For &root) { + auto op = tvm::ffi::make_object(root); + data_ = std::move(op); + } +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_PARALLEL_H_ diff --git a/tilelang/original/src/op/reduce.cc b/tilelang/original/src/op/reduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..4458a4f51864a13d8680416f0f16d4669b047fe6 --- /dev/null +++ b/tilelang/original/src/op/reduce.cc @@ -0,0 +1,584 @@ +/*! 
+ * \file tl/op/reduce.cc + * \brief Implementation of reduction operators + */ + +#include "reduce.h" + +#include +#include +#include +#include + +#include "../layout/utils.h" +#include "../op/parallel.h" +#include "../target/utils.h" +#include "../transform/loop_partition.h" +#include "tir/transforms/ir_utils.h" +#include "tvm/tir/stmt.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} + +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} + +ReduceOp::ReduceOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + // Accept BufferRegion/BufferLoad for src/dst + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); + node->src = node->srcRegion_->buffer; + node->dst = node->dstRegion_->buffer; + std::string reduce_type = args[2].as().value()->value; + node->dim = args[3].as().value()->value; + node->type = ReduceType(reduce_type); + node->clear = args[4].as().value(); + data_ = std::move(node); +} + +TileOperator ReduceOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return ReduceOp(op); +} + +TileOperator CumSumOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return CumSumOp(op); +} + +PrimExpr ReduceOpNode::MakeInitValue() const { + auto dst_dtype = dst->dtype; + auto is_int = dst_dtype.is_int(); + bool is_uint = dst_dtype.is_uint(); + auto bits = dst_dtype.bits(); + + if (type->isSum()) { + return make_zero(dst->dtype); + } else if (type->isAbsSum()) { + return make_zero(dst->dtype); + } else if (type->isMax()) { + if (is_int) { + return make_const(dst->dtype, -(1 << (bits - 1))); + } else if (is_uint) { + return make_const(dst->dtype, 0); + } else { + return make_const(dst->dtype, -INFINITY); + } + } else if (type->isMin()) { + if (is_int) { + return make_const(dst->dtype, (1 << (bits - 1)) - 1); + } else if (is_uint) { + return make_const(dst->dtype, (1 << bits) - 1); + } else { + return make_const(dst->dtype, INFINITY); + } + } else if (type->isAbsMax()) { + return make_const(dst->dtype, 0); + } else if (type->isBitAnd()) { + if (is_int) { + return make_const(dst->dtype, -1); + } else if (is_uint) { + return make_const(dst->dtype, (1 << bits) - 1); + } else { + // Should not arrive here + return make_const(dst->dtype, -INFINITY); + } + } else if (type->isBitOr()) { + return make_zero(dst->dtype); + } else if (type->isBitXor()) { + return make_zero(dst->dtype); + } else { + LOG(FATAL) << "Unsupported reduce type: " << type->type; + return PrimExpr(); + } +} + +PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs, + const PrimExpr &b) const { + PrimExpr rhs = b; + if (lhs->dtype != rhs->dtype) { + rhs = Cast(lhs->dtype, rhs); + } + if (type->isSum()) { + return lhs + rhs; + } else if (type->isAbsSum()) { + return lhs + Max(rhs, -rhs); + } else if (type->isMax()) { + return Max(lhs, rhs); + } else if (type->isMin()) { + return Min(lhs, rhs); + } else if (type->isAbsMax()) { + return Max(tvm::abs(lhs), tvm::abs(rhs)); + } else if (type->isBitAnd()) { + return lhs & rhs; + } else if (type->isBitOr()) { + return lhs | rhs; + } else if (type->isBitXor()) { + return lhs ^ rhs; + } else { + LOG(FATAL) << "Unsupported reduce type: " << type->type; + } +} + +std::string ReduceOpNode::MakeCodegenReducer() const { + if (type->isSum()) { + return "tl::SumOp"; + } else if (type->isAbsSum()) { + return "tl::SumOp"; + } else if (type->isMax()) { + return "tl::MaxOp"; + } else if (type->isMin()) { + return 
"tl::MinOp"; + } else if (type->isAbsMax()) { + return "tl::MaxOp"; + } else if (type->isBitAnd()) { + return "tl::BitAndOp"; + } else if (type->isBitOr()) { + return "tl::BitOrOp"; + } else if (type->isBitXor()) { + return "tl::BitXorOp"; + } else { + LOG(FATAL) << "Unsupported reduce type: " << type->type; + return ""; + } +} + +/** + * @brief Lower the Reduce operator to a TIR statement. + * + * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of + * TIR statements implementing: optional initialization, thread-local reduction + * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call + * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true), + * and an optional accumulation or copy back to the destination buffer when a + * temporary clear buffer is used. + * + * Behavior notes: + * - Only supports src and dst in "local.fragment" scope; otherwise it checks + * and aborts with "Reduce for shared memory not implemented.". + * - Supports both 1D reductions (scalar output) and reductions along a single + * extra dimension; validates layout dimensionality consistency. + * - If `clear` is set (or for sum/abssum reductions), an initial value is + * written to the clear buffer; for non-clearing sum/abssum a duplicate + * temporary buffer is allocated and accumulated back into dst after + * reduction. + * - Performs iterator compression for local reduction loops using `analyzer`. + * - Detects parallel thread splitting from the normalized iterator sum and + * emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`) + * via `builtin::call_extern`. For sufficiently large reducing thread counts + * (>= 32) a workspace is allocated via T.AddWorkspace and passed to the + * AllReduce call. + * - The final body is wrapped in parallel loops over the destination spatial + * dimensions and partitioned by the lowering thread variable. If a temporary + * clear buffer is used, it is allocated for the body. + * + * @param T Lowering context providing buffer and layout maps, thread bounds, + * target information, thread variable, and workspace allocation + * helper. + * @param analyzer Analyzer used for iterator compression and arithmetic + * normalization. + * @return Stmt Lowered TIR statement implementing the reduction. 
+ */ +Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + auto get_buffer = [&](const Buffer &buf) { + if (T.buffer_remap.count(buf)) + return T.buffer_remap[buf]; + return buf; + }; + + auto src_scope = this->src.scope(); + auto dst_scope = this->dst.scope(); + + if (src_scope == "local.fragment" && dst_scope == "local.fragment") { + + Buffer src_buffer = get_buffer(this->src); + Buffer dst_buffer = get_buffer(this->dst); + Fragment src_layout = T.layout_map[this->src].as().value(); + Fragment dst_layout = T.layout_map[this->dst].as().value(); + size_t src_dim = src_layout->InputDim(); + size_t dst_dim = dst_layout->InputDim(); + + bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1; + + if (is_1d_reduce) { + ICHECK(is_one(dst_layout->OutputShape().back())) + << "Reduce for scalar not implemented."; + } else { + ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch."; + } + + Array dst_vars; + for (size_t i = 0; i < dst_dim; ++i) { + Var var = Var(std::string{char('i' + i)}); + dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, + IterVarType::kDataPar)); + } + + Array src_vars; + if (!is_1d_reduce) { + src_vars = dst_vars; + } + Range reduce_dom(0, src_layout->InputShape()[this->dim]); + IterVar reduce_iv(reduce_dom, Var("rv"), IterVarType::kDataPar); + src_vars.insert(src_vars.begin() + this->dim, reduce_iv); + + Array src_indices = src_layout->Forward( + src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); })); + Array dst_indices = dst_layout->Forward( + dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); })); + + Array stmts; + + bool require_init = this->clear; + if (this->type->isSum() || this->type->isAbsSum() || + this->type->isBitAnd() || this->type->isBitOr() || + this->type->isBitXor()) { + require_init = true; + } + + Buffer clear_buffer = dst_buffer; + bool need_duplicate = false; + if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) { + need_duplicate = true; + } else if (this->type->isBitAnd() && !this->clear) { + need_duplicate = true; + } else if ((this->type->isBitOr() || this->type->isBitXor()) && + !this->clear) { + need_duplicate = true; + } + + if (need_duplicate) { + // Create a new buffer with same shape and dtype as dst_buffer + clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype, + dst_buffer->name + "_clear", + GetPtrStorageScope(dst_buffer->data)); + } + // make reduce-init stmt + if (require_init) { + stmts.push_back( + BufferStore(clear_buffer, this->MakeInitValue(), dst_indices)); + } + + // make thread-local reduce + Array src_indice_compressed; + Array src_var_compressed; + for (size_t i = 0; i < src_layout->OutputDim(); ++i) { + PrimExpr expr; + IterVar var; + std::tie(expr, var) = CompressIterator( + src_indices[i], src_vars, src_vars[this->dim]->var, analyzer); + src_indice_compressed.push_back(expr); + src_var_compressed.push_back(var); + } + + Stmt reduce_local = BufferStore( + clear_buffer, + this->MakeReduce(BufferLoad(clear_buffer, dst_indices), + BufferLoad(src_buffer, src_indice_compressed)), + dst_indices); + + for (int i = static_cast(src_layout->OutputDim()) - 1; i >= 0; --i) { + reduce_local = + For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, + ForKind::kUnrolled, reduce_local, std::nullopt, + {{tir::attr::pragma_unroll_explicit, Bool(false)}}); + } + stmts.push_back(reduce_local); + + PrimExpr src_thread = src_layout->ForwardThread( + src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {}); + auto 
iter_sum = + arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer); + for (const auto &iter_split : iter_sum->args) { + auto mark = iter_split->source->source.as(); + ICHECK(mark) << "Not a normalized iterator: " << iter_split->source; + if (mark.value().same_as(src_vars[this->dim]->var)) { + auto scale = as_const_int(iter_split->scale); + auto extent = as_const_int(iter_split->extent); + ICHECK(scale != nullptr && extent != nullptr); + if (*extent == 1) + continue; + + int reducing_threads = (*extent) * (*scale); + std::stringstream ss; + + auto thread_offset = T.thread_bounds->min; + if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) { + auto all_threads = T.thread_bounds->extent; + ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " + << reducing_threads << ", " << (*scale) << ", " << thread_offset + << ", " << all_threads << ">::run_hopper"; + } else { + ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " + << reducing_threads << ", " << (*scale) << ", " << thread_offset + << ">::run"; + } + Array thread_reduce_args = { + StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)}; + if (reducing_threads >= 32) { + PrimExpr workspace = T.AddWorkspace( + *as_const_int(T.thread_bounds->extent), clear_buffer->dtype); + thread_reduce_args.push_back(workspace); + } + auto call = Call(clear_buffer->dtype, builtin::call_extern(), + thread_reduce_args); + stmts.push_back(BufferStore(clear_buffer, call, dst_indices)); + } + } + + if (need_duplicate) { + PrimExpr src_val = BufferLoad(clear_buffer, dst_indices); + PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices); + PrimExpr update; + if (this->type->isSum() || this->type->isAbsSum()) { + update = dst_val + src_val; + } else if (this->type->isBitAnd()) { + update = this->clear ? src_val : bitwise_and(dst_val, src_val); + } else if (this->type->isBitOr()) { + update = bitwise_or(dst_val, src_val); + } else if (this->type->isBitXor()) { + update = bitwise_xor(dst_val, src_val); + } else { + LOG(FATAL) << "Unsupported reduce type: " << this->type->type; + } + stmts.push_back(BufferStore(dst_buffer, update, dst_indices)); + } + + Stmt body = stmts.size() > 1 ? 
SeqStmt(stmts) : stmts[0]; + for (int i = static_cast(dst_layout->InputDim()) - 1; i >= 0; --i) { + body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, + ForKind::kParallel, body); + } + + if (dst_layout->InputDim() > 0) { + body = PartitionLoop(Downcast(body), T.thread_var, analyzer, + dst_layout); + } else { + PrimExpr guard = (T.thread_var == T.thread_bounds->min); + body = IfThenElse(guard, body); + } + + if (need_duplicate) { + body = Allocate(clear_buffer->data, clear_buffer->dtype, + clear_buffer->shape, const_true(), body); + } + return body; + } + + LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", " + << dst_scope << ") is not implemented."; + return Stmt(); +} + +LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (level >= InferLevel::kStrict) + return {}; + + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && + T.layout_map.count(src)) { + auto src_layout = T.layout_map[src].as().value(); + + PrimExpr indice_rep_extent = src->shape[dim]; + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent; + + Array fwd; + for (int i = 0; i < static_cast(src->shape.size()); i++) { + if (i == dim) { + fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent)); + } else if (i < dim) { + fwd.push_back(InputPlaceholder(i)); + } else if (i > dim) { + fwd.push_back(InputPlaceholder(i - 1)); + } + } + auto thd = src_layout->ForwardThread( + fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); + + // Ensure the thread count is divisible by the replicate extent. + // Otherwise, we cannot infer a valid fragment<->fragment layout. + { + arith::Analyzer analyzer; + PrimExpr num_threads = T.thread_bounds->extent; + // Though the dest_buffer_rep_extent will be compressed at + // CondenseReplicateVar, we need to check the divisibility here to avoid + // the issue that the thread count is not divisible by the replicate + // extent. + if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) == + 0) && + !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) == + 0)) { + ICHECK(false) << "ReduceOp fragment layout inference failed: " + "num_threads % replicate_extent != 0. " + << "This mapping requires the block's thread count to be " + "divisible by the " + << "replicate extent. 
" + << "Try one of: (1) choose a thread block size divisible " + "by replicate_extent; " + << "(2) pick a different reduce dimension or adjust the " + "source fragment layout; " + << "Details: num_threads=" << num_threads + << ", replicate_extent=" << indice_rep_extent + << ", src=" << src << ", dst=" << dst; + } + } + + Fragment dst_layout = + Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt) + ->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); + + if (!T.layout_map.count(dst)) + return {{dst, dst_layout}}; + else { + // Check if computed layout is compatible with existing: the existing one + // must strictly contains the computed layout + auto orig_dst_layout = + T.layout_map.Get(dst).value().as().value(); + ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + + ICHECK(as_const_int(dst_layout->ReplicateExtent())); + ICHECK(as_const_int(src_layout->ReplicateExtent())); + auto dst_rep = *as_const_int(dst_layout->ReplicateExtent()); + auto src_rep = *as_const_int(src_layout->ReplicateExtent()); + if (dst_rep < src_rep || + !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices, + inner_analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. " + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } + + if (dst_rep > src_rep) { + return {{dst, dst_layout}}; + } + } + } + return {}; +} + +TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the +// ranges. +static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { + Array ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); +} + +CumSumOp::CumSumOp(Array args) { + /// CumSum constructor arguments: + /// - src: input buffer + /// - dst: output buffer + /// - dim: dimension to cumsum + /// - reverse: whether to cumsum in reverse order + CHECK_EQ(args.size(), 4); + ObjectPtr node = tvm::ffi::make_object(); + // node->src = vmap[GetVarFromAccessPtr(args[0])]; + // node->dst = vmap[GetVarFromAccessPtr(args[1])]; + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); + node->src = node->srcRegion_->buffer; + node->dst = node->dstRegion_->buffer; + node->dim = args[2].as().value()->value; + node->reverse = args[3].as().value(); + CHECK_LT(node->dim, static_cast(node->src->shape.size())) + << "The dim of cumsum should be less than the number of dimensions. 
Got " + "dim=" + << node->dim << ", but src has " << node->src->shape.size() << " dims."; + + data_ = std::move(node); +} + +Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + if (this->src.scope() == "local.fragment" && + this->dst.scope() == "local.fragment") { + LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue " + "if you need this feature."; + } else if (this->src.scope() == "shared.dyn" || + this->src.scope() == "shared") { + ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared"); + std::stringstream ss; + auto threads = T.thread_bounds->extent; + Array args; + + // Build access pointers from regions locally + PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1); + PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2); + + // Use region extents instead of buffer shape for correct slice handling + Array src_extents; + for (const auto &range : srcRegion_->region) { + src_extents.push_back(range->extent); + } + int ndim = static_cast(src_extents.size()); + + if (ndim == 1) { + ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim " + "= 0."; + ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false") + << ">::run"; + args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0]}; + } else if (ndim == 2) { + ss << "tl::CumSum2D<" << threads << ", " << dim << ", " + << (reverse ? "true" : "false") << ">::run"; + args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0], + src_extents[1]}; + } else { + LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got " + << ndim << "D."; + } + return Evaluate(Call(dst->dtype, builtin::call_extern(), args)); + } else { + ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and " + << this->dst.scope(); + } + + return Stmt(); +} + +LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + +TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + ReduceOpNode::RegisterReflection(); + CumSumOpNode::RegisterReflection(); + ReduceTypeNode::RegisterReflection(); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/reduce.h b/tilelang/original/src/op/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..cab3835e19899ba65e9f0e68237b16a813d640f6 --- /dev/null +++ b/tilelang/original/src/op/reduce.h @@ -0,0 +1,173 @@ +/*! 
+ * \file tl/op/reduce.h + * \brief Reduction operators for tensor computations + */ + +#ifndef TVM_TL_OP_REDUCE_H_ +#define TVM_TL_OP_REDUCE_H_ + +#include "operator.h" + +namespace tvm { + +namespace tl { + +using namespace tir; + +/// Supported reduction operation types +enum class ReduceTypeEnum : uint8_t { + kSum, ///< Sum reduction + kAbsSum, ///< Absolute sum reduction + kMax, ///< Maximum value reduction + kMin, ///< Minimum value reduction + kAbsMax, ///< Maximum absolute value reduction + kBitAnd, ///< Bitwise and reduction + kBitOr, ///< Bitwise or reduction + kBitXor, ///< Bitwise xor reduction +}; + +/// Node class representing a reduction type +class ReduceTypeNode : public Object { +public: + int type{-1}; ///< Internal type identifier + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("type", &ReduceTypeNode::type); + } + + /// Type checking methods + bool isSum() const { return type == int(ReduceTypeEnum::kSum); } + bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); } + bool isMax() const { return type == int(ReduceTypeEnum::kMax); } + bool isMin() const { return type == int(ReduceTypeEnum::kMin); } + bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); } + bool isBitAnd() const { return type == int(ReduceTypeEnum::kBitAnd); } + bool isBitOr() const { return type == int(ReduceTypeEnum::kBitOr); } + bool isBitXor() const { return type == int(ReduceTypeEnum::kBitXor); } +}; + +/// Wrapper class for reduction type with string-based construction +class ReduceType : public ObjectRef { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef, + ReduceTypeNode); + TVM_DLL ReduceType(std::string type) { + auto node = tvm::ffi::make_object(); + if (type == "sum") { + node->type = int(ReduceTypeEnum::kSum); + } else if (type == "abssum") { + node->type = int(ReduceTypeEnum::kAbsSum); + } else if (type == "max") { + node->type = int(ReduceTypeEnum::kMax); + } else if (type == "absmax") { + node->type = int(ReduceTypeEnum::kAbsMax); + } else if (type == "min") { + node->type = int(ReduceTypeEnum::kMin); + } else if (type == "bitand") { + node->type = int(ReduceTypeEnum::kBitAnd); + } else if (type == "bitor") { + node->type = int(ReduceTypeEnum::kBitOr); + } else if (type == "bitxor") { + node->type = int(ReduceTypeEnum::kBitXor); + } else { + LOG(FATAL) << "Invalid reduce type: " << type; + } + data_ = std::move(node); + } +}; + +/// Node class for reduction operations +class ReduceOpNode : public TileOperatorNode { +public: + tir::Buffer src, dst; ///< Source and destination buffers + // Optional: keep the original regions used to construct this op + BufferRegion srcRegion_, dstRegion_; + int dim; ///< Dimension to reduce along + ReduceType type; ///< Type of reduction operation + bool clear; ///< Whether to clear destination before reduction + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &ReduceOpNode::src) + .def_ro("dst", &ReduceOpNode::dst) + .def_ro("srcRegion", &ReduceOpNode::srcRegion_) + .def_ro("dstRegion", &ReduceOpNode::dstRegion_) + .def_ro("dim", &ReduceOpNode::dim) + .def_ro("type", &ReduceOpNode::type) + .def_ro("clear", &ReduceOpNode::clear); + } + + /// Lower the operator to TIR statements + Stmt Lower(const LowerArgs &T, 
arith::Analyzer *analyzer) const override; + /// Infer memory layout for buffers + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + static const Op &Get(); + TileOperator Clone() const; + +private: + /// Generate initial value for reduction + PrimExpr MakeInitValue() const; + /// Generate reduction expression + PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const; + /// Generate codegen reducer string + std::string MakeCodegenReducer() const; +}; + +/// Wrapper class for reduction operations +class ReduceOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator, + ReduceOpNode); + TVM_DLL ReduceOp(Array args); + static const Op &Get(); +}; + +/// Node class for cumulative sum operations +class CumSumOpNode : public TileOperatorNode { +public: + tir::Buffer src, dst; ///< Source and destination buffers + // Optional: keep the original regions used to construct this op + BufferRegion srcRegion_, dstRegion_; + int dim; ///< Dimension along which to compute cumulative sum + bool reverse; ///< Whether to compute in reverse order + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &CumSumOpNode::src) + .def_ro("dst", &CumSumOpNode::dst) + .def_ro("srcRegion", &CumSumOpNode::srcRegion_) + .def_ro("dstRegion", &CumSumOpNode::dstRegion_) + .def_ro("dim", &CumSumOpNode::dim) + .def_ro("reverse", &CumSumOpNode::reverse); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + static const Op &Get(); + TileOperator Clone() const; +}; + +/// Wrapper class for cumulative sum operations +class CumSumOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator, + CumSumOpNode); + TVM_DLL CumSumOp(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_REDUCE_H_ diff --git a/tilelang/original/src/op/region.cc b/tilelang/original/src/op/region.cc new file mode 100644 index 0000000000000000000000000000000000000000..25e78eba89ef7a112976e347b8919053ef34bc9f --- /dev/null +++ b/tilelang/original/src/op/region.cc @@ -0,0 +1,87 @@ +/*! + * \file tl/op/region.cc + * \brief Define region operator (bridge to carry BufferRegion via Call args). + * + * Notes: + * - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane + * count. Dynamic extents like (H1 - H0) cannot be encoded as + * Ramp(lanes = H1 - H0), and lowering BufferRegion to BufferLoad loses the + * explicit extent information. + * - tl.region carries both mins and extents in Call args and lets the backend + * reconstruct a BufferRegion faithfully. 
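+ *
+ * Illustrative example (frontend syntax shown schematically): a slice such as
+ * A[H0:H1, 0:64] travels as tl.region(A[H0, 0], access_mask, H1 - H0, 64),
+ * i.e. the BufferLoad indices carry the per-axis minima while the trailing
+ * call arguments carry the (possibly dynamic) extents, which the backend
+ * turns back into Range::FromMinExtent pairs.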
+ */ + +#include "region.h" +#include + +namespace tvm { +namespace tl { +using namespace tir; + +RegionOp::RegionOp(Array args) { + size_t n = args.size(); + size_t ndim = n - 2; + auto load = args[0].as(); + ICHECK(load); + ICHECK(load->indices.size() == ndim) + << "load->indices.size() = " << load->indices << " ndim = " << ndim; + Array ranges; + // Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents + for (size_t i = 0; i < ndim; i++) { + PrimExpr index = load->indices[i]; + PrimExpr extent = args[2 + i]; + if (const auto *ramp = index.as()) { + const auto *stride_imm = ramp->stride.as(); + ICHECK(stride_imm && stride_imm->value == 1) + << "RegionOp expects stride-1 Ramp for index"; + if (const auto *lanes_imm = ramp->lanes.as()) { + if (const auto *ext_imm = extent.as()) { + ICHECK_EQ(lanes_imm->value, ext_imm->value) + << "Ramp lanes and provided extent must match"; + } + } + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, extent)); + } + } + ObjectPtr node = tvm::ffi::make_object(); + node->buffer_ = load->buffer; + node->access_mask_ = static_cast(*as_const_int(args[1])); + node->ranges_ = ranges; + data_ = std::move(node); +} + +TileOperator RegionOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return RegionOp(op); +} + +bool RegionOpNode::IsFullRegion() const { + for (size_t i = 0; i < ranges_.size(); i++) { + if (!is_zero(ranges_[i]->min)) + return false; + if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) + return false; + } + return true; +} + +Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + return Evaluate(0); +} + +LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + +TIR_REGISTER_TL_TILE_OP(RegionOp, region) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); } + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/region.h b/tilelang/original/src/op/region.h new file mode 100644 index 0000000000000000000000000000000000000000..24399f7ab551d80c7f64a91c94061063e0520913 --- /dev/null +++ b/tilelang/original/src/op/region.h @@ -0,0 +1,91 @@ +/*! + * \file tl/op/region.h + * \brief Tile memory region descriptor op (bridge to carry BufferRegion via + * Call args). + * + * Why tl.region instead of passing BufferRegion directly? + * + * - While TIR can represent a BufferRegion, when a BufferRegion is passed as a + * call argument through call_intrin/FFI, the Python->C++ conversion lowers it + * to a BufferLoad(indices). To encode an interval inside indices, the FFI + * typically uses Ramp(base, stride, lanes) to represent a contiguous slice. + * - Ramp(lanes) may only be a constant or vscale*k (scalable vector). A general + * PrimExpr (e.g., H1 - H0) is not allowed as lanes, so dynamic extents would + * make the lowered BufferLoad invalid. + * - Moreover, BufferLoad only carries indices, not per-axis extents. Downstream + * tile operators (e.g., tl.copy, tl.reduce) that require both min and extent + * cannot losslessly recover dynamic extents from a BufferLoad alone. + * + * tl.region is a small transport-only op that solves this: + * - The frontend packs buffer + mins (from BufferLoad.indices) + extents into + * Call args, allowing dynamic extents to be expressed explicitly. 
+ * - The backend (NormalizeToBufferRegion) reconstructs a BufferRegion from the + * tl.region call without losing information. + * - The op itself carries no semantics in Lower/InferLayout and is only used as + * a bridge for argument passing. + */ + +#ifndef TVM_TL_OP_REGION_H_ +#define TVM_TL_OP_REGION_H_ + +#include "./operator.h" +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class RegionOpNode : public TileOperatorNode { +public: + Buffer buffer_; + Array ranges_; + int access_mask_; + + /*! + * access_mask_ encodes the intended access type when the region is used as + * an argument to tile operators: 1=read, 2=write, 3=read-write. The mask is + * transport metadata only and does not affect lowering. + */ + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode, + TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + const Buffer &GetBuffer() const { return buffer_; } + const Array &GetRanges() const { return ranges_; } + int GetAccessMask() const { return access_mask_; } + bool IsFullRegion() const; + + TileOperator Clone() const override; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &RegionOpNode::buffer_) + .def_ro("ranges", &RegionOpNode::ranges_) + .def_ro("access_mask", &RegionOpNode::access_mask_); + } +}; + +class RegionOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator, + RegionOpNode); + /*! + * Build a RegionOp from call arguments: + * - args[0]: BufferLoad whose indices are per-axis minima. + * - args[1]: Integer access mask (1=r, 2=w, 3=rw). + * - args[2 + i]: Extent of axis i (supports dynamic PrimExpr). + */ + TVM_DLL RegionOp(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_REGION_H_ diff --git a/tilelang/original/src/op/tcgen5_meta.h b/tilelang/original/src/op/tcgen5_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..8b6ff61ba3725928958c1bfa1c785264e74ee7b9 --- /dev/null +++ b/tilelang/original/src/op/tcgen5_meta.h @@ -0,0 +1,177 @@ +#ifndef TVM_TL_OP_TCGEN5_META_H_ +#define TVM_TL_OP_TCGEN5_META_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tl { + +using runtime::DataType; + +struct TCGEN5MMAMeta { + int atom_m, atom_n, atom_k; + bool enable_ws, enable_2cta; +}; + +inline std::pair +GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { +// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
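+// Dispatch sketch (a restatement of the branches below, for readability): the
+// helper returns {supported, meta}. For fp16/bf16 inputs with an fp32
+// accumulator, K must be a multiple of 16; M % 128 == 0 selects atom_m = 128
+// with the largest atom_n in {256, 240, ..., 16} dividing N, while
+// M % 64 == 0 or M % 32 == 0 use atom_m = 64 / 32 with enable_ws set and
+// atom_n drawn from {256, 128, 64}. The fp8/fp6/fp4 paths additionally
+// require K % 32 == 0 and may enable the 2-CTA variant.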
+#define FAIL \ + return { \ + false, TCGEN5MMAMeta { 0, 0, 0, false, false } \ + } +#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \ + } + std::vector ws_valid_atom_ns = {256, 128, 64}; + if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 16 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 16, false, false); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 16, true, false); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 16, true, false); + FAIL; + } else { + FAIL; + } + } else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() || + ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) && + ((c_dtype.is_float() && c_dtype.bits() == 32) || + (c_dtype.is_float16() && c_dtype.bits() == 16))) { + if (K % 32 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, true, false); + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, true); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32, true, false); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32, false, false); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 32, true, false); + FAIL; + } else { + FAIL; + } + } + FAIL; +#undef FAIL +#undef SUCCESS +} + +inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, + DataType ab_dtype, DataType c_dtype, + bool a_is_k_major, bool b_is_k_major, + int scale_in_a, int scale_in_b) { + ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; + ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; + ICHECK(atom_k == 16 || atom_k == 32) + << "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k; + ICHECK(scale_in_a == 1 || scale_in_a == -1) + << "scale_in_a must be +/-1 for TCGEN5MMA"; + ICHECK(scale_in_b == 1 || scale_in_b == -1) + << "scale_in_b must be +/-1 for TCGEN5MMA"; + + auto encode_dtype = [&](DataType dtype) -> uint32_t { + if (dtype.is_float16()) { + return static_cast(0); + } else if (dtype.is_bfloat16()) { + return static_cast(1); + } else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || + dtype.is_float8_e4m3()) { + return static_cast(0); + } else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) { + return static_cast(1); + } + LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype; + return 0u; + }; + + uint32_t a_format = encode_dtype(ab_dtype); + uint32_t b_format = a_format; + + uint32_t c_format = 0; + if (c_dtype.is_float16()) { + c_format = 0; + } else if (c_dtype.is_float()) { + c_format = 1; + } else if (c_dtype.is_int()) { + c_format = 2; + } else { + LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: " + << c_dtype; + } + + auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { + uint32_t mask = (width == 32) ? 
0xFFFFFFFFu : ((1u << width) - 1); + return (value & mask) << start; + }; + + uint32_t desc = 0; + desc |= set_bits(0, 0, 2); // sparse_id2 + desc |= set_bits(0, 2, 1); // sparse_flag + desc |= set_bits(0, 3, 1); // saturate + desc |= set_bits(c_format, 4, 2); + + desc |= set_bits(a_format, 7, 3); + desc |= set_bits(b_format, 10, 3); + + uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; + uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; + desc |= set_bits(a_neg, 13, 1); + desc |= set_bits(b_neg, 14, 1); + + uint32_t a_major = a_is_k_major ? 0u : 1u; + uint32_t b_major = b_is_k_major ? 0u : 1u; + desc |= set_bits(a_major, 15, 1); + desc |= set_bits(b_major, 16, 1); + + uint32_t n_dim = static_cast(atom_n >> 3); + uint32_t m_dim = static_cast(atom_m >> 4); + desc |= set_bits(n_dim, 17, 6); + desc |= set_bits(0, 23, 1); + desc |= set_bits(m_dim, 24, 5); + desc |= set_bits(0, 29, 1); + + uint32_t max_shift = 0u; + desc |= set_bits(max_shift, 30, 2); + + return desc; +} + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_TCGEN5_META_H_ diff --git a/tilelang/original/src/op/utils.cc b/tilelang/original/src/op/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e56ae8c7a51ef519e047fbb50f857b2f4673bc0 --- /dev/null +++ b/tilelang/original/src/op/utils.cc @@ -0,0 +1,96 @@ +/*! + * \file tl/op/utils.cc + * \brief Common utilities implementation for TL ops. + */ + +#include "utils.h" + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { + // Case 1: Already a BufferRegion + if (arg->IsInstance()) { + return Downcast(arg); + } + + // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else + // extent=1) + if (const auto *load = arg.as()) { + Array ranges; + for (const PrimExpr &index : load->indices) { + if (const auto *ramp = index.as()) { + ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; + ICHECK_EQ(ramp->stride.as()->value, 1) + << "Only stride-1 Ramp is supported in region conversion"; + ICHECK(ramp->lanes.as()) + << "Scalable vector lanes not supported in region conversion"; + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, 1)); + } + } + return BufferRegion(load->buffer, ranges); + } + + // Case 3: tl.region(...) 
— reconstruct via RegionOp (bridge) + if (const auto *call = arg.as()) { + if (call->op.same_as(RegionOp::Get())) { + RegionOp region(call->args); + return BufferRegion(region->GetBuffer(), region->GetRanges()); + } + LOG(FATAL) << "Unsupported argument for BufferRegion (expect " + "BufferLoad/BufferRegion/tl.region): " + << arg; + } + + LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg; + throw; // Unreachable +} + +PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, + bool require_2d) { + Buffer buf = region->buffer; + int ndim = static_cast(buf->shape.size()); + if (require_2d) { + ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims"; + } + + PrimExpr offset, extent; + if (ndim == 1) { + // 1D: straightforward + auto axis = region->region[0]; + offset = axis->min; + extent = axis->extent; + } else { + // Compute row-major strides + std::vector strides(ndim); + PrimExpr one = make_const(buf->shape[0].dtype(), 1); + PrimExpr cur = one; + for (int i = ndim - 1; i >= 0; --i) { + strides[i] = cur; + cur = cur * buf->shape[i]; + } + // Offset: sum_{i in [0..ndim-3]} min_i * stride_i + offset = make_const(buf->shape[0].dtype(), 0); + for (int i = 0; i < ndim - 2; ++i) { + offset = offset + region->region[i]->min * strides[i]; + } + // Extent: last two extents product (elements) + extent = + region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; + } + + // ptype and return handle + PrimExpr ptype = tir::TypeAnnotation(buf->dtype); + Array acc_args{ptype, buf->data, offset, extent, + IntImm(DataType::Int(32), rw_mask)}; + return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/op/utils.h b/tilelang/original/src/op/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d386b1a58d74b22bddb24f52b2d05120d316890d --- /dev/null +++ b/tilelang/original/src/op/utils.h @@ -0,0 +1,35 @@ +/*! + * \file tl/op/utils.h + * \brief Common utilities for TL ops. + */ + +#ifndef TVM_TL_OP_UTILS_H_ +#define TVM_TL_OP_UTILS_H_ + +#include "./operator.h" +#include "region.h" +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Normalize an argument (BufferRegion/BufferLoad/tl.region) +// to BufferRegion so ops can uniformly consume regions. +// Note: tvm_access_ptr is no longer supported here. +TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg); + +// Build a tvm_access_ptr(handle) from a BufferRegion. +// - If `require_2d` is true, checks buffer ndim >= 2. +// - For 1D regions (when allowed), offset=min, extent=extent. +// - For ndim >= 2, offset sums all but last two dims using row-major strides, +// extent is product of the last two extents. +TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, + int rw_mask, bool require_2d = false); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_UTILS_H_ diff --git a/tilelang/original/src/runtime/error_helpers.cc b/tilelang/original/src/runtime/error_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..903f8b1d9ec8a5dafb6a093609848c5dbd4832b1 --- /dev/null +++ b/tilelang/original/src/runtime/error_helpers.cc @@ -0,0 +1,222 @@ +/* + * Helper functions for nicer runtime error messages. 
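 *
 * Each helper below formats a human-readable message, records it with
 * TVMFFIErrorSetRaisedFromCStr("RuntimeError", ...), and then either returns a
 * non-zero status (the typed legacy helpers) or throws ffi::EnvErrorAlreadySet
 * (the packed variants) so the generated host stub returns -1.
 *
 * Illustrative call shape (a sketch only; the kernel and buffer names here are
 * made up, not part of the API):
 *   __tvm_error_ndim_mismatch("main_kernel", "A", 2, 3)
 *   -> RuntimeError: "kernel main_kernel input A ndim expected 2, but got 3"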
+ */ +#include "error_helpers.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tl { + +// Return non-zero so that tvm_call_packed sites treat it as failure and return +// -1. +static int DTypeMismatch(const tvm::ffi::String &kernel_name, + const tvm::ffi::String &buffer_name, + int64_t actual_code, int64_t actual_bits, + int64_t actual_lanes, int64_t expect_code, + int64_t expect_bits, int64_t expect_lanes) { + tvm::runtime::DataType actual(static_cast(actual_code), + static_cast(actual_bits), + static_cast(actual_lanes)); + tvm::runtime::DataType expect(static_cast(expect_code), + static_cast(expect_bits), + static_cast(expect_lanes)); + std::ostringstream os; + os << "kernel " << std::string(kernel_name) << " input " + << std::string(buffer_name) << " dtype expected " << expect << ", but got " + << actual; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + return -1; +} + +// Variant without names, to avoid passing extra raw strings through packed +// args. +static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits, + int64_t actual_lanes, int64_t expect_code, + int64_t expect_bits, int64_t expect_lanes) { + tvm::runtime::DataType actual(static_cast(actual_code), + static_cast(actual_bits), + static_cast(actual_lanes)); + tvm::runtime::DataType expect(static_cast(expect_code), + static_cast(expect_bits), + static_cast(expect_lanes)); + std::ostringstream os; + os << "dtype mismatch: expected " << expect << ", but got " << actual; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + return -1; +} + +// Register packed versions, following the design in runtime.cc +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + + // Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name, + // actual_code, actual_bits, actual_lanes, + // expect_code, expect_bits, expect_lanes) + refl::GlobalDef().def_packed( + tl::tvm_error_dtype_mismatch, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, " + "actual_code, actual_bits, actual_lanes, " + << "expect_code, expect_bits, expect_lanes"; + + auto kernel_name = args[0].cast(); + auto buffer_name = args[1].cast(); + int64_t actual_code = args[2].cast(); + int64_t actual_bits = args[3].cast(); + int64_t actual_lanes = args[4].cast(); + int64_t expect_code = args[5].cast(); + int64_t expect_bits = args[6].cast(); + int64_t expect_lanes = args[7].cast(); + + // Reuse the helper to format the message + (void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits, + actual_lanes, expect_code, expect_bits, + expect_lanes); + // Provide a return value for completeness, then signal the error + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, expect:int64, got:int64 + refl::GlobalDef().def_packed( + tl::tvm_error_ndim_mismatch, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 4) + << "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + int64_t expect = args[2].cast(); + int64_t got = args[3].cast(); + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << " ndim expected " << expect << ", but got " + << got; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, expect:int64, 
got:int64 + refl::GlobalDef().def_packed( + tl::tvm_error_byte_offset_mismatch, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 4) + << "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + int64_t expect = args[2].cast(); + int64_t got = args[3].cast(); + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << " byte_offset expected " << expect + << ", but got " << got; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, expect:int64, got:int64 + refl::GlobalDef().def_packed( + tl::tvm_error_device_type_mismatch, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 4) + << "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + int64_t expect = args[2].cast(); + int64_t got = args[3].cast(); + const char *expect_str = + tvm::runtime::DLDeviceType2Str(static_cast(expect)); + const char *got_str = + tvm::runtime::DLDeviceType2Str(static_cast(got)); + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << " device_type expected " << expect_str + << ", but got " << got_str; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, field:String + refl::GlobalDef().def_packed( + tl::tvm_error_null_ptr, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 3) + << "__tvm_error_null_ptr(kernel, buffer, field)"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + auto field = args[2].cast(); + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << ' ' << std::string(field) + << " expected non-NULL, but got NULL"; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, field:String, expect:int64, got:int64 + refl::GlobalDef().def_packed( + tl::tvm_error_expect_eq, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 5) + << "__tvm_error_expect_eq(kernel, buffer, field, expect, got)"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + auto field = args[2].cast(); + int64_t expect = args[3].cast(); + int64_t got = args[4].cast(); + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << ' ' << std::string(field) << " expected " + << expect << ", but got " << got; + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // kernel, buffer, field:String [, reason:String] + refl::GlobalDef().def_packed( + tl::tvm_error_constraint_violation, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + ICHECK(args.size() == 3 || args.size() == 4) + << "__tvm_error_constraint_violation(kernel, buffer, field[, " + "reason])"; + auto kernel = args[0].cast(); + auto buffer = args[1].cast(); + auto field = args[2].cast(); + std::string reason; + if (args.size() == 4) { + reason = args[3].cast(); + } + std::ostringstream os; + os << "kernel " << std::string(kernel) << " input " + << std::string(buffer) << ' ' << std::string(field) + << " constraint not satisfied"; + if (!reason.empty()) { + 
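// Attach the optional caller-supplied reason to the generic constraint message.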
os << ": " << reason; + } + TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); + *ret = -1; + throw ::tvm::ffi::EnvErrorAlreadySet(); + }); + + // Legacy typed registrations for backward compatibility + refl::GlobalDef().def("tilelang_error_dtype_mismatch", + &tvm::tl::DTypeMismatch); + refl::GlobalDef().def("tilelang_error_dtype_mismatch2", + &tvm::tl::DTypeMismatchNoNames); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/runtime/error_helpers.h b/tilelang/original/src/runtime/error_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..6620d837e9031d795a16ebc40c1a81d709717904 --- /dev/null +++ b/tilelang/original/src/runtime/error_helpers.h @@ -0,0 +1,27 @@ +/*! + * \file tl/runtime/error_helpers.h + * \brief Error helper FFI names for TileLang runtime. + */ + +#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_ +#define TVM_TL_RUNTIME_ERROR_HELPERS_H_ + +namespace tvm { +namespace tl { + +// Error helper packed functions +constexpr const char *tvm_error_dtype_mismatch = "__tvm_error_dtype_mismatch"; +constexpr const char *tvm_error_ndim_mismatch = "__tvm_error_ndim_mismatch"; +constexpr const char *tvm_error_byte_offset_mismatch = + "__tvm_error_byte_offset_mismatch"; +constexpr const char *tvm_error_device_type_mismatch = + "__tvm_error_device_type_mismatch"; +constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr"; +constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq"; +constexpr const char *tvm_error_constraint_violation = + "__tvm_error_constraint_violation"; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_ diff --git a/tilelang/original/src/runtime/runtime.cc b/tilelang/original/src/runtime/runtime.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2a7127d264a709626e3d5ab5ae7527a5f1c2a13 --- /dev/null +++ b/tilelang/original/src/runtime/runtime.cc @@ -0,0 +1,349 @@ +/*! + * \file tl/runtime/runtime.h + * \brief Runtime functions. 
+ * + */ + +#include "runtime.h" + +#include "../target/cuda.h" +#include +#include + +namespace tvm { +namespace tl { + +#if 1 +// Thread-local storage for restoring the L2 persisting cache limit +static thread_local size_t __tl_prev_persisting_l2_cache_size = 0; +static thread_local bool __tl_prev_persisting_l2_cache_saved = false; +#endif + +#if (CUDA_MAJOR_VERSION >= 12) +template static std::string ArrayToStr(const T *ptr, size_t n) { + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < n; i++) { + if (i > 0) + ss << ", "; + ss << ptr[i]; // NOLINT(clang-analyzer-security.ArrayBound) + } + ss << "]"; + return ss.str(); +} + +struct TensorMapArgs { + CUtensorMap *map; + CUtensorMapDataType type; + cuuint32_t tensorRank; + void *globalAddress; + cuuint64_t globalDim[5], globalStride[5]; + cuuint32_t boxDim[5], elementStrides[5]; + CUtensorMapInterleave interleave; + CUtensorMapSwizzle swizzle; + CUtensorMapL2promotion l2Promotion; + CUtensorMapFloatOOBfill oobFill; + + static TensorMapArgs Extract(PackedArgs args) { + TensorMapArgs T; + int idx = 0; + ICHECK(args.size() >= 8); + T.map = reinterpret_cast(args[idx++].cast()); + T.type = static_cast(args[idx++].cast()); + T.tensorRank = static_cast(args[idx++].cast()); + T.globalAddress = args[idx++].cast(); + ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5); + ICHECK(args.size() == static_cast(8 + T.tensorRank * 4)); + for (size_t i = 0; i < T.tensorRank; i++) { + T.globalDim[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank; i++) { + T.globalStride[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank; i++) { + T.boxDim[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank; i++) { + T.elementStrides[i] = args[idx++].cast(); + } + T.interleave = + static_cast(args[idx++].cast()); + T.swizzle = static_cast(args[idx++].cast()); + T.l2Promotion = + static_cast(args[idx++].cast()); + T.oobFill = + static_cast(args[idx++].cast()); + return T; + } + + std::string ToDebugString() { + std::stringstream ss; + ss << "TMA Desc Addr: " << map << '\n' + << "format " << type << '\n' + << "dim " << tensorRank << '\n' + << "gmem_address " << globalAddress << '\n' + << "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n' + << "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n' + << "boxDim " << ArrayToStr(boxDim, tensorRank) << '\n' + << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n' + << "interleave " << interleave << '\n' + << "swizzle " << swizzle << '\n' + << "l2Promotion " << l2Promotion << '\n' + << "oobFill " << oobFill << '\n'; + return ss.str(); + } +}; + +// set device api +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + // Register using the canonical names defined in runtime.h + refl::GlobalDef().def_packed( + tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) { + TensorMapArgs T = TensorMapArgs::Extract(args); + CUresult result = cuTensorMapEncodeTiled( + T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, + T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, + T.swizzle, T.l2Promotion, T.oobFill); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to initialize the TMA descriptor " << result + << '\n' + << T.ToDebugString(); + } + *ret = static_cast(result); + }); +} + +struct TensorMapIm2ColArgs { + CUtensorMap *map; + CUtensorMapDataType type; + cuuint32_t tensorRank; + void *globalAddress; + cuuint64_t globalDim[5], globalStride[5]; + cuuint32_t elementStrides[5]; + int pixelBoxLowerCorner[3], 
pixelBoxUpperCorner[3]; + cuuint32_t smem_box_channel, smem_box_pixel; + CUtensorMapInterleave interleave; + CUtensorMapSwizzle swizzle; + CUtensorMapL2promotion l2Promotion; + CUtensorMapFloatOOBfill oobFill; + + static TensorMapIm2ColArgs Extract(PackedArgs args) { + TensorMapIm2ColArgs T; + int idx = 0; + ICHECK(args.size() >= 8); + T.map = reinterpret_cast(args[idx++].cast()); + T.type = static_cast(args[idx++].cast()); + T.tensorRank = static_cast(args[idx++].cast()); + T.globalAddress = args[idx++].cast(); + ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5); + ICHECK(args.size() == static_cast(6 + T.tensorRank * 5)); + for (size_t i = 0; i < T.tensorRank; i++) { + T.globalDim[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank; i++) { + T.globalStride[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank; i++) { + T.elementStrides[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank - 2; i++) { + T.pixelBoxLowerCorner[i] = args[idx++].cast(); + } + for (size_t i = 0; i < T.tensorRank - 2; i++) { + T.pixelBoxUpperCorner[i] = args[idx++].cast(); + } + T.smem_box_pixel = args[idx++].cast(); + T.smem_box_channel = args[idx++].cast(); + T.interleave = + static_cast(args[idx++].cast()); + T.swizzle = static_cast(args[idx++].cast()); + T.l2Promotion = + static_cast(args[idx++].cast()); + T.oobFill = + static_cast(args[idx++].cast()); + return T; + } + + std::string ToDebugString() { + std::stringstream ss; + ss << "TMA Desc Addr: " << map << '\n' + << "format " << type << '\n' + << "dim " << tensorRank << '\n' + << "gmem_address " << globalAddress << '\n' + << "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n' + << "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n' + << "smem_box_pixel " << smem_box_pixel << '\n' + << "smem_box_channel " << smem_box_channel << '\n' + << "pixelBoxLowerCorner " + << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << '\n' + << "pixelBoxUpperCorner " + << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << '\n' + << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n' + << "interleave " << interleave << '\n' + << "swizzle " << swizzle << '\n' + << "l2Promotion " << l2Promotion << '\n' + << "oobFill " << oobFill << '\n'; + return ss.str(); + } +}; + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) { + TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args); + CUresult result = cuTensorMapEncodeIm2col( + T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, + T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, + T.smem_box_channel, T.smem_box_pixel, T.elementStrides, + T.interleave, T.swizzle, T.l2Promotion, T.oobFill); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to initialize the TMA descriptor " << result + << '\n' + << T.ToDebugString(); + } + *ret = static_cast(result); + }); +} + +#endif // (CUDA_MAJOR_VERSION >= 12) + +// +// CUDA L2 Persisting Cache Access Policy Window helpers. +// Exposed as TVM FFI packed functions similar to TMA initialization. 
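//
// Call-shape sketch for the two packed functions registered below (argument
// names here are illustrative only; the layout follows the parsing code):
//   __tvm_cuda_stream_set_access_policy_window(
//       base_ptr,    // void*  : start of the region to persist in L2
//       num_bytes,   // int64  : size of the region in bytes
//       0.8,         // double : optional hit ratio (defaults to 0.8)
//       stream,      // void*  : optional CUstream, 0 selects the default stream
//       num_bytes);  // int64  : optional L2 limit, clamped to the device maximum
//   __tvm_cuda_stream_reset_access_policy_window(stream);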
+// +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + // Set stream access policy window and adjust persisting L2 cache size + // Args: + // [0]: void* base_ptr (required) + // [1]: int64 num_bytes (required) + // [2]: float hit_ratio (optional, default 0.8) + // [3]: void* stream (optional, default 0 => default stream) + // [4]: int64 l2_limit_bytes (optional, default = num_bytes) + refl::GlobalDef().def_packed( + tl::tvm_cuda_stream_set_access_policy_window, + [](PackedArgs args, Any *ret) { + ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes"; + + void *base_ptr = args[0].cast(); + size_t num_bytes = static_cast(args[1].cast()); + float hit_ratio = 0.8f; + if (args.size() >= 3) { + // Accept double/float + hit_ratio = static_cast(args[2].cast()); + } + CUstream stream = nullptr; + if (args.size() >= 4) { + stream = reinterpret_cast(args[3].cast()); + } + size_t l2_limit_bytes = num_bytes; + if (args.size() >= 5) { + l2_limit_bytes = static_cast(args[4].cast()); + } + + // Clamp requested limit to device capability + CUdevice device; + CUresult result = cuCtxGetDevice(&device); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to get current CUDA device: " << result; + } + int max_persisting = 0; + result = cuDeviceGetAttribute( + &max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE, + device); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: " + << result; + } + if (max_persisting > 0 && + l2_limit_bytes > static_cast(max_persisting)) { + l2_limit_bytes = static_cast(max_persisting); + } + + // Save current limit to restore later + size_t init_persisting_l2_cache_size = 0; + result = cuCtxGetLimit(&init_persisting_l2_cache_size, + CU_LIMIT_PERSISTING_L2_CACHE_SIZE); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to get current persisting L2 cache size limit: " + << result; + } + __tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size; + __tl_prev_persisting_l2_cache_saved = true; + + // Set new limit + result = + cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to set persisting L2 cache size limit: " + << result; + } + + // Apply access policy window to stream + CUstreamAttrValue stream_attribute; + memset(&stream_attribute, 0, sizeof(stream_attribute)); + stream_attribute.accessPolicyWindow.base_ptr = base_ptr; + stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes; + stream_attribute.accessPolicyWindow.hitRatio = hit_ratio; + stream_attribute.accessPolicyWindow.hitProp = + CU_ACCESS_PROPERTY_PERSISTING; + stream_attribute.accessPolicyWindow.missProp = + CU_ACCESS_PROPERTY_STREAMING; + + result = cuStreamSetAttribute(stream, + CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, + &stream_attribute); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to set stream access policy window: " << result; + } + + *ret = static_cast(result); + }); + + // Reset stream access policy window and restore the previous L2 cache size + // Args: + // [0]: void* stream (optional, default 0) + refl::GlobalDef().def_packed( + tl::tvm_cuda_stream_reset_access_policy_window, + [](PackedArgs args, Any *ret) { + CUstream stream = nullptr; + if (args.size() >= 1) { + stream = reinterpret_cast(args[0].cast()); + } + + CUstreamAttrValue stream_attribute; + memset(&stream_attribute, 0, sizeof(stream_attribute)); + // num_bytes = 0 disables the access policy window on the stream + 
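// (base_ptr stays NULL from the memset above, so the window is fully cleared.)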
stream_attribute.accessPolicyWindow.num_bytes = 0; + + CUresult result = cuStreamSetAttribute( + stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, + &stream_attribute); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to reset stream access policy window: " + << result; + } + + result = cuCtxResetPersistingL2Cache(); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result; + } + + if (__tl_prev_persisting_l2_cache_saved) { + result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, + __tl_prev_persisting_l2_cache_size); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to restore persisting L2 cache size limit: " + << result; + } + __tl_prev_persisting_l2_cache_saved = false; + } + + *ret = static_cast(result); + }); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/runtime/runtime.h b/tilelang/original/src/runtime/runtime.h new file mode 100644 index 0000000000000000000000000000000000000000..4b389fc03e365bd4d9d131ceedb332aa2cdf51be --- /dev/null +++ b/tilelang/original/src/runtime/runtime.h @@ -0,0 +1,28 @@ +/*! + * \file tl/runtime/runtime.h + * \brief Runtime functions. + * + */ + +#ifndef TVM_TL_RUNTIME_RUNTIME_H_ +#define TVM_TL_RUNTIME_RUNTIME_H_ + +namespace tvm { +namespace tl { + +#if (CUDA_MAJOR_VERSION >= 12) +constexpr const char *tvm_tensormap_create_tiled = + "__tvm_tensormap_create_tiled"; +constexpr const char *tvm_tensormap_create_im2col = + "__tvm_tensormap_create_im2col"; +#endif // (CUDA_MAJOR_VERSION >= 12) + +// CUDA stream access policy window helpers +constexpr const char *tvm_cuda_stream_set_access_policy_window = + "__tvm_cuda_stream_set_access_policy_window"; +constexpr const char *tvm_cuda_stream_reset_access_policy_window = + "__tvm_cuda_stream_reset_access_policy_window"; +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_RUNTIME_RUNTIME_H_ diff --git a/tilelang/original/src/support/ffi_aliases.h b/tilelang/original/src/support/ffi_aliases.h new file mode 100644 index 0000000000000000000000000000000000000000..7dbe0b39501e861d040766847d1beb45c82a1f28 --- /dev/null +++ b/tilelang/original/src/support/ffi_aliases.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +using ffi::Array; +using ffi::Function; +using ffi::Map; +using ffi::Optional; +using ffi::String; +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_c_host.cc b/tilelang/original/src/target/codegen_c_host.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f1a70cef86662eabdeb416b2066c2a9e4957fd4 --- /dev/null +++ b/tilelang/original/src/target/codegen_c_host.cc @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file codegen_c_host.cc + */ +#include "codegen_c_host.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// For escaping strings embedded into generated C sources +#include "support/str_escape.h" + +namespace tvm { +namespace tl { + +CodeGenCHost::CodeGenCHost() { + module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx); +} + +void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, + bool emit_fwd_func_decl, std::string target_str, + const std::unordered_set &devices) { + emit_asserts_ = emit_asserts; + emit_fwd_func_decl_ = emit_fwd_func_decl; + declared_globals_.clear(); + decl_stream << "// tilelang target: " << target_str << "\n"; + decl_stream << "#define TVM_EXPORTS\n"; + decl_stream << "#include \"tvm/runtime/base.h\"\n"; + decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; + decl_stream << "#include \n"; + // snprintf for richer assert messages with actual values + decl_stream << "#include \n"; + decl_stream << "#include \n"; + CodeGenCHost::InitGlobalContext(); + tvm::codegen::CodeGenC::Init(output_ssa); +} + +void CodeGenCHost::InitGlobalContext() { + decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx + << " = NULL;\n"; +} + +void CodeGenCHost::DefineModuleName() { + decl_stream << "void* " << module_name_ << " = NULL;\n"; +} + +void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &func) { + return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false); +} + +void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &func, + bool emit_fwd_func_decl) { + auto global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol); + if (global_symbol) { + function_names_.push_back(global_symbol.value()); + } + + emit_fwd_func_decl_ = emit_fwd_func_decl; + tvm::codegen::CodeGenC::AddFunction(gvar, func); + if (func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc) && !has_main_func_) { + ICHECK(global_symbol.has_value()) + << "CodeGenCHost: The entry func must have the global_symbol " + "attribute, " + << "but function " << gvar << " only has attributes " << func->attrs; + function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main); + stream << "// CodegenC: NOTE: Auto-generated entry function\n"; + PrintFuncPrefix(stream); + PrintType(func->ret_type, stream); + stream << " " << tvm::ffi::symbol::tvm_ffi_main + << "(void* self, void* args,int num_args, void* result) {\n"; + stream << " return " << static_cast(global_symbol.value()) + << "(self, args, num_args, result);\n"; + stream << "}\n"; + has_main_func_ = true; + } +} + +void CodeGenCHost::GenerateForwardFunctionDeclarations( + tvm::ffi::String global_symbol, const tvm::ffi::Array &arg_types, + const tvm::Type &ret_type) { + if (!emit_fwd_func_decl_) { + return; + } + for (auto &func_already_defined : GetFunctionNames()) { + if (global_symbol == func_already_defined) { + return; + } + } + this->PrintFuncPrefix(fwd_decl_stream); + this->PrintType(ret_type, fwd_decl_stream); + fwd_decl_stream << " " << global_symbol << "("; + for (size_t i = 0; i < arg_types.size(); ++i) { + if (i > 0) { + fwd_decl_stream << ", "; + } + tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream); + } + fwd_decl_stream << ");\n"; +} + +void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*) + os << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; +} + +void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) 
{ // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; + } + if (t.is_void()) { + os << "void"; + return; + } + if (t == tvm::DataType::Bool()) { + os << "bool"; + return; + } + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } + if (t.is_bfloat16()) { + os << "__bf16"; + return; + } + if (t.is_int() || t.is_uint()) { + if (t.is_uint()) { + os << 'u'; + } + switch (t.bits()) { + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to C type"; +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + int lanes = op->dtype.lanes(); + os << "(("; + PrintType(op->dtype, os); + os << ")("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << "))"; +} + +void CodeGenCHost::PrintGetFuncFromBackend( + const std::string &func_name, const std::string &packed_func_name) { + this->PrintIndent(); + this->stream << "if (" << packed_func_name << " == NULL) {\n"; + int packed_func_if_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" + << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; + int get_func_env_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(get_func_env_scope); + this->PrintIndent(); + this->stream << "}\n"; + this->EndScope(packed_func_if_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) { + using namespace tvm::tir; + const StringImmNode *func_name = op->args[0].as(); + ICHECK(func_name != nullptr) + << "tvm_call_[c]packed_lowered expects first argument as function name"; + int64_t begin = op->args[2].as()->value; + int64_t end = op->args[3].as()->value; + int64_t num_args = end - begin; + ICHECK_GE(num_args, 0); + + std::string packed_func_name; + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + packed_func_name = GetPackedName(op); + this->PrintGetFuncFromBackend(func_name->value, packed_func_name); + } else { + // directly use the original symbol + ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); + packed_func_name = + tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; + } + + std::string args_stack = PrintExpr(op->args[1]); + this->PrintIndent(); + std::string result = name_supply_->FreshName("result"); + this->stream << "TVMFFIAny " << result << ";\n"; + this->PrintIndent(); + // must make sure type_index is set to none + this->stream << result << ".type_index = kTVMFFINone;\n"; + this->PrintIndent(); + this->stream << result << ".zero_padding = 0;\n"; + this->PrintIndent(); + this->stream << result << ".v_int64 = 0;\n"; + this->PrintIndent(); + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + 
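// Packed path: call through the function handle resolved from the backend
// environment; the cpacked branch below invokes the prefixed symbol directly.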
this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", "; + } else { + this->stream << "if (" << packed_func_name << "(NULL, "; + } + this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", " + << "&" << result << ") != 0) {\n"; + int func_call_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(func_call_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) { + using namespace tvm::tir; + const StringImmNode *s = op->args[0].as(); + ICHECK(s != nullptr) + << "tvm_call_packed_lowered expects first argument as function name"; + std::string func_name = s->value; + std::string packed_func_name = func_name + "_packed"; + std::string unique_name; + auto it = declared_globals_.find(packed_func_name); + if (it != declared_globals_.end()) { + unique_name = it->second; + } else { + unique_name = name_supply_->FreshName(packed_func_name); + declared_globals_[packed_func_name] = unique_name; + decl_stream << "static void* " << unique_name << " = NULL;\n"; + } + return unique_name; +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op, + std::ostream &os) { // NOLINT(*) + using namespace tvm::tir; + if (op->op.same_as(builtin::tvm_stack_alloca())) { + std::string stack_name = name_supply_->FreshName("stack"); + const std::string &type = op->args[0].as()->value; + const IntImmNode *num = op->args[1].as(); + ICHECK(num != nullptr); + static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); + size_t unit = sizeof(TVMFFIAny); + size_t size = 0; + if (type == "shape") { + size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit; + } else if (type == "tvm_ffi_any") { + size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit; + } else if (type == "array") { + size = (num->value * sizeof(DLTensor) + unit - 1) / unit; + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + } + this->PrintIndent(); + this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; + os << stack_name; + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + this->PrintCallPacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + this->PrintCallPacked(op); + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + this->PrintIndent(); + this->stream << "return -1;\n"; + } else { + tvm::codegen::CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*) + if (emit_asserts_) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + stream << "if (!(" << cond << ")) {\n"; + int assert_if_scope = this->BeginScope(); + { + // Prepare the base error message: allow StringImm or general PrimExpr + const auto *msg_node = op->message.as(); + bool msg_is_literal = (msg_node != nullptr); + std::string esc_msg; + std::string msg_expr; + if (msg_is_literal) { + const std::string &raw_msg = msg_node->value; + esc_msg = tvm::support::StrEscape( + raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, + /*escape_whitespace_special_chars=*/true); + } else { + msg_expr = PrintExpr(op->message); + } + + // Only print expected/got values for equality when message is StringImm + if (msg_is_literal) { + if (const auto *eq = op->condition.as()) { + std::string lhs = PrintExpr(eq->a); + std::string rhs = PrintExpr(eq->b); + PrintIndent(); + stream << "char __tvm_assert_msg_buf[512];\n"; + PrintIndent(); + stream << 
"snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, " + "got: %lld\", \"" + << esc_msg << "\", (long long)(" << lhs << "), (long long)(" + << rhs << "));\n"; + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " + "__tvm_assert_msg_buf);\n"; + } else { + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" + << esc_msg << "\");\n"; + } + } else { + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " << msg_expr + << ");\n"; + } + } + PrintIndent(); + stream << "return -1;\n"; + this->EndScope(assert_if_scope); + PrintIndent(); + stream << "}\n"; + } + this->PrintStmt(op->body); +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, "<", os); +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, ">", os); +} + +template +inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os) { // NOLINT(*) + std::ostringstream temp_a; + VisitExpr(op->a, temp_a); + std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); + std::ostringstream temp_b; + VisitExpr(op->b, temp_b); + std::string b_id = SSAGetID(temp_b.str(), op->b.dtype()); + + os << "((" << a_id << ") " << compare << " (" << b_id << ") " + << "? (" << a_id << ") : (" << b_id << "))"; +} + +} // namespace tl +} // namespace tvm + +namespace tvm { +namespace tl { + +using tvm::codegen::CodeGenSourceBase; +using tvm::codegen::CSourceModuleCreate; +using tvm::ffi::Array; +using tvm::ffi::Map; +using tvm::ffi::Module; +using tvm::ffi::String; + +// Build function that mirrors TVM's host C codegen, registered under a +// TileLang-specific name. 
+::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, + ::tvm::Target target) { + bool output_ssa = false; + bool emit_asserts = true; + bool emit_fwd_func_decl = true; + + std::unordered_set devices; + if (mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>( + "device_contexts") != nullptr) { + ::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String> device_contexts = + mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>( + "device_contexts") + .value(); + for (auto const &context : device_contexts) { + devices.insert(context.second.data()); + } + } + + CodeGenCHost cg; + cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); + cg.SetConstantsByteAlignment( + target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16)); + + auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool { + return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false)) + .value(); + }; + + std::vector> funcs; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>()) + << "CodegenCHost: Can only take PrimFunc"; + auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func); + funcs.push_back({gvar, prim_func}); + } + + auto sort_key = [&is_aot_executor_fn](const auto &kv) { + return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; + }; + std::sort(funcs.begin(), funcs.end(), + [&sort_key](const auto &kv_a, const auto &kv_b) { + return sort_key(kv_a) < sort_key(kv_b); + }); + + for (const auto &[gvar, prim_func] : funcs) { + cg.DeclareFunction(gvar, prim_func); + } + + for (const auto &[gvar, prim_func] : funcs) { + cg.AddFunction(gvar, prim_func, emit_fwd_func_decl); + } + + std::string code = cg.Finish(); + return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_c_host.h b/tilelang/original/src/target/codegen_c_host.h new file mode 100644 index 0000000000000000000000000000000000000000..8d54cb4ad9496ef6c86f93b4412f3c15148fe937 --- /dev/null +++ b/tilelang/original/src/target/codegen_c_host.h @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_c_host.h + * \brief Generate C host code (TileLang copy). 
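 *
 * Declares tvm::tl::CodeGenCHost, a host-side C emitter derived from
 * tvm::codegen::CodeGenC and used by the "target.build.tilelang_c" build path
 * (see codegen_c_host.cc).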
+ */ +#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ + +#include +#include +#include +#include +#include + +#include "target/source/codegen_c.h" +#include "tvm/target/codegen.h" +#include "tvm/tir/expr.h" + +namespace tvm { +namespace tl { + +// TileLang copy of TVM's CodeGenCHost, under the tl namespace. +// Inherits from tvm::codegen::CodeGenC. +class CodeGenCHost : public tvm::codegen::CodeGenC { +public: + CodeGenCHost(); + void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, + std::string target_str, + const std::unordered_set &devices); + + void InitGlobalContext(); + + void AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &f) override; + void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f, + bool emit_fwd_func_decl); + /*! + * \brief Add functions from the (unordered) range to the current module in a + * deterministic order. This helps with debugging. + * + * \param functions A vector of unordered range of current module. + */ + void AddFunctionsOrdered( + std::vector> functions); + void DefineModuleName(); + + using tvm::codegen::CodeGenC::PrintType; + void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*) + void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*) + + // overload visitor functions + void VisitExpr_(const tvm::tir::BroadcastNode *op, + std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const tvm::tir::CallNode *op, + std::ostream &os) override; // NOLINT(*) + // overload min and max to use the ternary operator, so we don't rely on the + // standard library implementations + void VisitExpr_(const tvm::tir::MinNode *op, + std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const tvm::tir::MaxNode *op, + std::ostream &os) final; // NOLINT(*) + + void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*) + + void GenerateForwardFunctionDeclarations( + tvm::ffi::String global_symbol, + const tvm::ffi::Array &arg_types, + const tvm::Type &ret_type) override; + tvm::ffi::Array GetFunctionNames() { + return function_names_; + } + +private: + std::string module_name_; + /* \brief mapping global packed func to the unique name */ + std::unordered_map declared_globals_; + /* \brief names of the functions declared in this module */ + tvm::ffi::Array function_names_; + /*! \brief whether to emit asserts in the resulting C code */ + bool emit_asserts_; + /*! \brief whether to emit forwared function declarations in the resulting C + * code */ + bool emit_fwd_func_decl_; + /*! \brief whether to generate the entry function if encountered */ + bool has_main_func_ = false; + + std::string GetPackedName(const tvm::tir::CallNode *op); + void PrintGetFuncFromBackend(const std::string &func_name, + const std::string &packed_func_name); + void PrintCallPacked(const tvm::tir::CallNode *op); + /*! + * \brief Print ternary conditional operator implementing binary `op` + * Forces the operands to be in SSA form. 
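 * For example, Min(a, b) is printed as ((a) < (b) ? (a) : (b)) with both
 * operands first bound to SSA temporaries so each is evaluated only once.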
+ * \param op binary operator being expressed + * \param compare string representation of comparison operator + * \param os stream reference to print into + */ + template + inline void PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os); // NOLINT(*) +}; + +} // namespace tl +} // namespace tvm + +#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ diff --git a/tilelang/original/src/target/codegen_cpp.cc b/tilelang/original/src/target/codegen_cpp.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f736bb066810d7d9abc720f72bfa1c3ab4dd873 --- /dev/null +++ b/tilelang/original/src/target/codegen_cpp.cc @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_c_host.cc + */ +#include "codegen_cpp.h" + +#include +#include + +#include +#include +#include + +#include "../op/builtin.h" +#include "../support/ffi_aliases.h" +#include "support/str_escape.h" +#include "target/build_common.h" +#include "target/source/codegen_params.h" + +namespace tvm { +namespace codegen { + +CodeGenTileLangCPP::CodeGenTileLangCPP() { + module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); +} + +void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts, + bool emit_fwd_func_decl, std::string target_str, + const std::unordered_set &devices) { + emit_asserts_ = emit_asserts; + emit_fwd_func_decl_ = emit_fwd_func_decl; + declared_globals_.clear(); + decl_stream << "// tilelang target: " << target_str << "\n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "\n"; + CodeGenC::Init(output_ssa); +} + +void CodeGenTileLangCPP::InitGlobalContext() { + decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n"; +} + +void CodeGenTileLangCPP::DefineModuleName() { + decl_stream << "void* " << module_name_ << " = NULL;\n"; +} + +void CodeGenTileLangCPP::GenerateForwardFunctionDeclarations( + String global_symbol, + + const Array &arg_types, const Type &ret_type) { + if (!emit_fwd_func_decl_) { + return; + } + for (auto &func_already_defined : GetFunctionNames()) { + if (global_symbol == func_already_defined) { + return; + } + } + this->PrintFuncPrefix(fwd_decl_stream); + this->PrintType(ret_type, fwd_decl_stream); + fwd_decl_stream << " " << global_symbol << "("; + for (size_t i = 0; i < arg_types.size(); ++i) { + if (i > 0) { + fwd_decl_stream << ", "; + } + CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream); + } + fwd_decl_stream << ");\n"; +} + +void CodeGenTileLangCPP::PrintFuncPrefix(std::ostream &os) { // NOLINT(*) + os << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; +} + +void CodeGenTileLangCPP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + 
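// Opaque handles lower to a plain void*; vector-of-handle types are rejected.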
ICHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; + } + if (t.is_void()) { + os << "void"; + return; + } + if (t == DataType::Bool()) { + os << "bool"; + return; + } + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << 'u'; + } + switch (t.bits()) { + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to C type"; +} + +void CodeGenTileLangCPP::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + int lanes = op->dtype.lanes(); + os << "(("; + PrintType(op->dtype, os); + os << ")("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << "))"; +} + +void CodeGenTileLangCPP::PrintGetFuncFromBackend( + const std::string &func_name, const std::string &packed_func_name) { + this->PrintIndent(); + this->stream << "if (" << packed_func_name << " == NULL) {\n"; + int packed_func_if_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" + << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; + int get_func_env_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(get_func_env_scope); + this->PrintIndent(); + this->stream << "}\n"; + this->EndScope(packed_func_if_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name, + int num_args) { + this->PrintIndent(); + std::string ret_val = name_supply_->FreshName("ret_val"); + std::string ret_type_code = name_supply_->FreshName("ret_type_code"); + this->stream << "TVMFFIAny " << ret_val << ";\n"; + this->PrintIndent(); + this->stream << "int " << ret_type_code << ";\n"; + this->PrintIndent(); + this->stream << "if (TVMFuncCall(" << packed_func_name << ", " + << "(TVMFFIAny*) stack_value" + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code << ") != 0) {\n"; + int func_call_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(func_call_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +void CodeGenTileLangCPP::PrintFuncCallC( + const std::string &packed_func_name, int num_args, + const std::string &resource_handle_name) { + this->PrintIndent(); + std::string ret_val = name_supply_->FreshName("ret_val"); + std::string ret_type_code = name_supply_->FreshName("ret_type_code"); + this->stream << "TVMFFIAny " << ret_val << ";\n"; + this->PrintIndent(); + this->stream << "int " << ret_type_code << ";\n"; + this->PrintIndent(); + + this->stream << "if (" << packed_func_name << "( " + << "(TVMFFIAny*) stack_value " + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code 
<< ", " << resource_handle_name + << ") != 0){\n"; + + int func_call_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(func_call_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) { + // clear previous generated state. + this->InitFuncState(f); + // reserve keywords + ReserveKeywordsAsUnique(); + + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + std::unordered_set non_restrict; + if (auto opt = + f->GetAttr>(tl::attr::kNonRestrictParams)) { + for (const tir::Var &v : opt.value()) + non_restrict.insert(v.get()); + } + + this->PrintFuncPrefix(stream); + CodeGenC::PrintType(f->ret_type, stream); + this->PrintExtraAttrs(f, stream); + this->stream << " " << static_cast(global_symbol.value()) << "("; + + for (size_t i = 0; i < f->params.size(); ++i) { + tir::Var v = f->params[i]; + std::string vid = AllocVarID(v.get()); + if (i != 0) + stream << ", "; + if (v.dtype().is_handle()) { + // work around for grid constant parameters. + if (auto *ptr = v->type_annotation.as()) { + if (ptr->storage_scope == "grid_constant") { + stream << "__grid_constant__ const "; + CodeGenC::PrintType(ptr->element_type, stream); + stream << ' ' << vid; + continue; + } + } + + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } + + CodeGenC::PrintType(GetType(v), stream); + if (auto *ptr = v->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } + + if (no_alias && !non_restrict.count(v.get())) { + PrintRestrict(v, stream); + } + } else { + CodeGenC::PrintType(GetType(v), stream); + } + stream << ' ' << vid; + } + stream << ") {\n"; + this->PreFunctionBody(f); + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; +} + +std::string CodeGenTileLangCPP::GetPackedName(const CallNode *op) { + const StringImmNode *s = op->args[0].as(); + ICHECK(s != nullptr) + << "tvm_call_packed_lowered expects first argument as function name"; + std::string func_name = s->value; + std::string packed_func_name = func_name + "_packed"; + std::string unique_name; + auto it = declared_globals_.find(packed_func_name); + if (it != declared_globals_.end()) { + unique_name = it->second; + } else { + unique_name = name_supply_->FreshName(packed_func_name); + declared_globals_[packed_func_name] = unique_name; + decl_stream << "static void* " << unique_name << " = NULL;\n"; + } + return unique_name; +} + +CodeGenTileLangCPP::FunctionInfo +CodeGenTileLangCPP::GetFunctionInfo(const CallNode *op, + bool has_resource_handle) { + const StringImmNode *s = op->args[0].as(); + ICHECK(s != nullptr) + << "tvm_call_[c]packed_lowered expects first argument as function name"; + int64_t begin = op->args[3].as()->value; + int64_t end = op->args[4].as()->value; + int64_t num_args = end - begin; + ICHECK_GE(num_args, 0); + std::string func_name = s->value; + + if (has_resource_handle) { + const StringImmNode *resource_handle_var = op->args[5].as(); + if (resource_handle_var != nullptr) { + std::string resource_handle_name = resource_handle_var->value; + return {func_name, num_args - 1, resource_handle_name}; + } else { + // The final arg should be 
"(void*) NULL" to indicate the empty + // resource_handle. + num_args--; + + const CallNode *reinterpret_call = op->args[5].as(); + ICHECK_NE(reinterpret_call, (void *)nullptr) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var " + "from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK_EQ(reinterpret_call->op, builtin::reinterpret()) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var " + "from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK(is_zero(reinterpret_call->args[0])) + << "At CallNode to " << s + << " arg 5: Expect either StringImm naming the " + "resource_handle var from interface API, or " + << "zero; got " << op->args[5]; + } + } + return {func_name, num_args, "NULL"}; +} + +void CodeGenTileLangCPP::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + if (op->op.same_as(builtin::tvm_stack_alloca())) { + std::string stack_name = name_supply_->FreshName("stack"); + const std::string &type = op->args[0].as()->value; + const IntImmNode *num = op->args[1].as(); + ICHECK(num != nullptr); + static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); + size_t unit = sizeof(TVMFFIAny); + size_t size = 0; + if (type == "shape") { + size = (num->value * sizeof(runtime::tvm_index_t) + unit - 1) / unit; + } else if (type == "arg_value") { + size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit; + } else if (type == "arg_tcode") { + size = (num->value * sizeof(int) + unit - 1) / unit; + } else if (type == "array") { + size = (num->value * sizeof(DLTensor) + unit - 1) / unit; + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + } + this->PrintIndent(); + this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; + os << stack_name; + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + auto function_info = GetFunctionInfo(op, false /* has_resource_handle */); + std::string func_name_packed = GetPackedName(op); + this->PrintGetFuncFromBackend(function_info.func_name, func_name_packed); + this->PrintFuncCall(func_name_packed, function_info.num_args); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + auto function_info = GetFunctionInfo(op, true /* has_resource_handle */); + this->PrintFuncCallC(function_info.func_name, function_info.num_args, + function_info.resource_handle_name); + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + this->PrintIndent(); + this->stream << "return -1;\n"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenTileLangCPP::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) + if (emit_asserts_) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + stream << "if (!(" << cond << ")) {\n"; + int assert_if_scope = this->BeginScope(); + PrintIndent(); + stream << "TVMAPISetLastError(\"" << op->message.as()->value + << "\");\n"; + PrintIndent(); + stream << "return -1;\n"; + this->EndScope(assert_if_scope); + PrintIndent(); + stream << "}\n"; + } + this->PrintStmt(op->body); +} + +void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + this->PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + + PrintType(op->dtype, stream); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + stream << ' ' << 
vid << '[' << constant_size << "];\n"; + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); +} + +void CodeGenTileLangCPP::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, "<", os); +} + +void CodeGenTileLangCPP::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, ">", os); +} + +template +inline void +CodeGenTileLangCPP::PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os) { // NOLINT(*) + std::ostringstream temp_a; + VisitExpr(op->a, temp_a); + std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); + std::ostringstream temp_b; + VisitExpr(op->b, temp_b); + std::string b_id = SSAGetID(temp_b.str(), op->b.dtype()); + + os << "((" << a_id << ") " << compare << " (" << b_id << ") " + << "? (" << a_id << ") : (" << b_id << "))"; +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_cpp.h b/tilelang/original/src/target/codegen_cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..25bb115c824eaedfd3e5aab030e595cbaec714a1 --- /dev/null +++ b/tilelang/original/src/target/codegen_cpp.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_c_host.h + * \brief Generate C host code. + */ +#ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ + +#include +#include +#include +#include +#include + +#include "target/source/codegen_c.h" +#include "tvm/target/codegen.h" +#include "tvm/tir/expr.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangCPP : public CodeGenC { +public: + CodeGenTileLangCPP(); + void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, + std::string target_str, + const std::unordered_set &devices); + + void InitGlobalContext(); + // Override this as a work around for non tvm runtime code generations + void AddFunction(const PrimFunc &f); + + /*! + * \brief Add functions from the (unordered) range to the current module in a + * deterministic order. This helps with debugging. + * + * \param functions A vector of unordered range of current module. 
+ */ + void AddFunctionsOrdered( + std::vector> functions); + void DefineModuleName(); + + using CodeGenC::PrintType; + void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) + void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*) + + // overload visitor functions + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*) + // overload min and max to use the ternary operator, so we don't rely on the + // standard library implementations + void VisitExpr_(const MinNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream &os) final; // NOLINT(*) + + void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*) + void VisitStmt_(const AllocateNode *op) final; // NOLINT(*) + + void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array &arg_types, + const Type &ret_type) override; + ffi::Array GetFunctionNames() { return function_names_; } + +private: + /* \brief Internal structure to store information about function calls */ + struct FunctionInfo { + /* \brief function name */ + std::string func_name; + /* number of arguments required by the function */ + int64_t num_args; + /* \brief name of resource_handle to pass */ + std::string resource_handle_name; + }; + std::string module_name_; + /* \brief mapping global packed func to the unique name */ + std::unordered_map declared_globals_; + /* \brief names of the functions declared in this module */ + ffi::Array function_names_; + /*! \brief whether to emit asserts in the resulting C code */ + bool emit_asserts_; + /*! \brief whether to emit forward function declarations in the resulting C + * code */ + bool emit_fwd_func_decl_; + + FunctionInfo GetFunctionInfo(const CallNode *op, bool has_resource_handle); + std::string GetPackedName(const CallNode *op); + void PrintGetFuncFromBackend(const std::string &func_name, + const std::string &packed_func_name); + void PrintFuncCall(const std::string &packed_func_name, int num_args); + void PrintFuncCallC(const std::string &packed_func_name, int num_args, + const std::string &resource_handle_name); + + /*! + * \brief Print ternary conditional operator implementing binary `op` + * Forces the operands to be in SSA form. + * \param op binary operator being expressed + * \param compare string representation of comparison operator + * \param os stream reference to print into + */ + template + inline void PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os); // NOLINT(*) +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ diff --git a/tilelang/original/src/target/codegen_cuda.cc b/tilelang/original/src/target/codegen_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..657871d8fb3c36e2c7aa20a2e5a490e49fc05958 --- /dev/null +++ b/tilelang/original/src/target/codegen_cuda.cc @@ -0,0 +1,3582 @@ +/*! 
+ * \file target/codegen.cc + */ + +#include "codegen_cuda.h" +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "./ptx.h" +#include "arith/pattern_match.h" + +namespace tvm { +namespace codegen { +using namespace tvm::tl::codegen; +using namespace ffi; + +struct CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + case 32: + return name + 'f'; + case 16: { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } + default: + return ""; + } + } else if (t.is_bfloat16()) { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } else if (t.is_int() || t.is_uint()) { + switch (t.bits()) { + case 32: + return "__" + name; + case 64: + return "__" + name + "ll"; + default: + return ""; + } + } + return ""; + } +}; + +struct CUDAFastMath : public CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float() && t.bits() == 32) { + return "__" + name + 'f'; + } else { + return CUDAMath::operator()(t, name); + } + return ""; + } +}; + +struct CUDAFastMathTan : public CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + // `__tanf` seems to produce some values too deviant from numpy tan + // version. So, let's use just `tanf` instead. + case 32: + return name + 'f'; + case 16: + return 'h' + name; + default: + return ""; + } + } + return ""; + } +}; + +struct CUDAIEEEMath { + std::string operator()(DataType t, std::string name, + std::string rounding_mode) const { + if (t.is_float() && t.bits() == 32) { + return "__" + name + "_" + rounding_mode; + } else if (t.is_float() && t.bits() == 64) { + return "__d" + name + "_" + rounding_mode; + } + return ""; + } +}; + +static std::string GetTileLangFP8Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "_2"; + } else if (lanes == 4) { + vec = "_4"; + } else if (lanes == 8) { + vec = "_8"; + } else if (lanes == 16) { + vec = "_16"; + } else if (lanes == 32) { + vec = "_32"; + } else { + LOG(FATAL) + << "Only support scalar and vector types of width (2, 4, 8, 16, 32) " + "for FP8"; + } + if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() || + type.is_float8_e4m3()) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) { + stream << "fp8_e5" << vec << "_t"; + } else if (type.is_float8_e8m0fnu()) { + stream << "fp8_e8" << vec << "_t"; + } else { + LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type; + } + return stream.str(); +} + +std::string GetTileLangFP6Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "x2"; + } else if (lanes == 4) { + vec = "x4"; + } else if (lanes == 8) { + vec = "x8"; + } else if (lanes == 16) { + vec = "x16"; + } else { + LOG(FATAL) + << "Only support scalar and vector types of width (2, 4) for FP6"; + } + stream << "__nv_fp6"; + std::string suffix; + if (type.code() == DataType::kFloat6_e2m3fn) { + suffix = "_e2m3"; + } else if (type.code() == DataType::kFloat6_e3m2fn) { + suffix = 
"_e3m2"; + } else { + LOG(FATAL) << "Unsupported FP6 type in CUDA codegen"; + } + stream << vec << suffix; + return stream.str(); +} + +std::string GetTileLangFP4Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "_2"; + } else if (lanes == 4) { + vec = "_4"; + } else if (lanes == 8) { + vec = "_8"; + } else if (lanes == 16) { + vec = "_16"; + } else if (lanes == 32) { + vec = "_32"; + } else if (lanes == 64) { + vec = "_64"; + } else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16, " + "32, 64) for FP4"; + } + + std::string suffix; + if (type.code() == DataType::kFloat4_e2m1fn) { + suffix = "_e2"; + } else { + LOG(FATAL) << "Unsupported FP4 type in CUDA codegen"; + } + + stream << "fp4" << suffix << vec << "_t"; + return stream.str(); +} + +CodeGenTileLangCUDA::CodeGenTileLangCUDA() { + restrict_keyword_ = "__restrict__"; + vid_global_barrier_state_ = + name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state); + vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect"); + ICHECK_EQ(vid_global_barrier_state_, + runtime::symbol::tvm_global_barrier_state); +} + +void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) { + os << "extern \"C\" __global__ "; +} + +class LaunchConfigExtractor : public tir::StmtVisitor { +private: + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->var->name_hint == "threadIdx.x" || + iv->thread_tag == "threadIdx.x") { + threadIdx_x_ext = op->value; + } else if (iv->var->name_hint == "threadIdx.y" || + iv->thread_tag == "threadIdx.y") { + threadIdx_y_ext = op->value; + } else if (iv->var->name_hint == "threadIdx.z" || + iv->thread_tag == "threadIdx.z") { + threadIdx_z_ext = op->value; + } + } + StmtVisitor::VisitStmt_(op); + } + +public: + PrimExpr threadIdx_x_ext = Integer(1); + PrimExpr threadIdx_y_ext = Integer(1); + PrimExpr threadIdx_z_ext = Integer(1); +}; + +void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) { + LaunchConfigExtractor extractor; + extractor(f->body); + arith::Analyzer analyzer; + PrimExpr threadIdx_ext = + analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * + extractor.threadIdx_z_ext); + if (const IntImmNode *const threadIdx_ext_int = + threadIdx_ext.as()) { + if (threadIdx_ext_int->value == 1) { + // unable to extract the number of threads per block, hence directly + // return + return; + } + stream << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)"; + } +} + +std::string CodeGenTileLangCUDA::Finish() { + if (need_mma_h_) { + decl_stream << "#include \n"; + } + if (need_mma_instruction_h_) { + decl_stream << "#include \n"; + } + if (need_wgmma_instruction_h_) { + decl_stream << "#include \n"; + } + if (need_tcgen05mma_instruction_h_) { + decl_stream << "#include \n"; + } + if (need_mma_sm70_instruction_h_) { + decl_stream << "#include \n"; + } + if (need_tcgen05_common_h_) { + decl_stream << "#include \n"; + } + if (enable_fp8_) { + decl_stream << "#include \n"; + } + if (enable_fp4_) { + decl_stream << "#include \n"; + } + + if (need_math_constants_h_) { + decl_stream << "#include \n"; + } + + if (need_cooperative_groups_) { + decl_stream << "#include \n"; + } + + if (need_curand_kernel_h_) { + decl_stream << "#include \n"; + } + + decl_stream << "#include \n"; + if (enable_sparse_gemm_) { + decl_stream << "#include \n"; + } 
+ decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#ifdef ENABLE_BF16\n"; + decl_stream << "#include \n"; + decl_stream << "#endif\n"; + + if (need_global_barrier_) { + decl_stream << "__device__ unsigned " << vid_global_barrier_state_ + << " = 0;\n"; + } + decl_stream << "\n"; + + return CodeGenC::Finish(); +} + +void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) { + if (op->kind == tir::ForKind::kUnrolled) { + PrintIndent(); + if (unroll_factor.count(op->loop_var.get())) { + stream << "#pragma unroll " + << PrintExpr(unroll_factor[op->loop_var.get()]) << "\n"; + } else { + stream << "#pragma unroll\n"; + } + } + std::string extent = + PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + std::string start = PrintExpr(op->min); + stream << "for ("; + PrintType(op->loop_var.dtype(), stream); + stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent + << "; ++" << vid << ") {\n"; + int for_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(for_scope); + PrintIndent(); + stream << "}\n"; +} + +void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); +} + +void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK(t.is_scalar()) << "do not yet support vector types"; + os << "void*"; + return; + } + + if (t.is_void()) { + os << "void"; + return; + } + + if (t == tl::cuTensorMapType()) { + os << "CUtensorMap"; + return; + } + + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: + enable_fp16_ = true; + if (t.is_scalar()) { + os << "half_t"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp16 vector elements. + // + // half4 is stored as uint2 + // + // h4.x is emitted as *(half2*)(&(u2.x)).x + // h4.y is emitted as *(half2*)(&(u2.x)).y + // h4.z is emitted as *(half2*)(&(u2.y)).x + // h4.w is emitted as *(half2*)(&(u2.y)).y + // + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else if (lanes <= 16) { + ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half " + "type of more than 8 lanes"; + os << "ulonglong" << lanes / 4; + } else { + fail = true; + } + break; + case 32: + if (lanes <= 4) { + os << "float"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. 
+ // + // float8 is stored as ulonglong4 + // + // f8.v1 is emitted as *(float2*)(&(ul4.x)).x + // f8.v2 is emitted as *(float2*)(&(ul4.x)).y + // + ICHECK_EQ(lanes % 2, 0) + << "only support even lane for float type with lanes > 4"; + os << "ulonglong" << lanes / 2; + } else { + fail = true; + } + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && (t.is_scalar() || t.bits() == 16)) + return; + if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) + return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } else if (t.is_bfloat16()) { + enable_bf16_ = true; + if (t.is_scalar()) { + os << "bfloat16_t"; + } else if (lanes <= 8) { + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else if (lanes <= 16) { + ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half type " + "of more than 8 lanes"; + os << "ulonglong" << lanes / 4; + } else { + fail = true; + } + if (!fail) + return; + } else if (t.is_float8()) { + enable_fp8_ = true; + os << GetTileLangFP8Type(t); + return; + } else if (t.is_float6()) { + enable_fp6_ = true; + if (t.lanes() <= 4) { + os << GetTileLangFP6Type(t); + } + return; + } else if (t.is_float4()) { + enable_fp4_ = true; + if (t.lanes() <= 64) { + os << GetTileLangFP4Type(t); + } else { + fail = true; + } + return; + } else if (t == DataType::Bool()) { + os << "bool"; + return; + } else if (t.is_vector_bool()) { + // CUDA does not support bool vectors. + // Use ushort vectors to represent instead. + int n = t.lanes(); + if (n <= 4) { + os << "ushort" << n; + return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << "u"; + } + switch (t.bits()) { + case 1: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 8) { + os << "int8_t"; + return; + } else if (t.lanes() == 16) { + os << "int16_t"; + return; + } else if (t.lanes() == 32) { + os << "int"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 4: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 4) { + os << "int16_t"; + return; + } else if (t.lanes() == 8) { + // directly 8 4-bit int in integer. + os << "int"; + return; + } else if (t.lanes() == 16) { + os << "int2"; + return; + } else if (t.lanes() == 32) { + os << "int4"; + return; + } else if (t.lanes() == 64) { + os << "int8"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 8: { + if (t.lanes() == 4) { + // directly 4 8 bit int in integer. + enable_int8_ = true; + + // We use int for int8x4 instead of char4 because using char4 is + // likely to produce extra instructions to pack four int8 elements + // into 32-bit data. + os << "int"; + return; + } else if (t.lanes() == 8) { + enable_int8_ = true; + os << "int2"; + return; + } else if (t.lanes() == 16) { + enable_int8_ = true; + os << "int4"; + return; + } else if (t.lanes() == 32) { + enable_int8_ = true; + os << "longlong4"; + return; + } else if (!t.is_uint() && t.is_scalar()) { + os << "signed char"; + break; + } else { + os << "char"; + break; + } + } + case 16: { + if (t.is_scalar()) { + os << "short"; + } else if (t.lanes() <= 4) { + os << "short" << lanes; + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int16 vector elements. 
+ // + // short4 is stored as int2 + // + // s4.x is emitted as *(short2*)(&(i2.x)).x + // s4.y is emitted as *(short2*)(&(i2.x)).y + // s4.z is emitted as *(short2*)(&(i2.y)).x + // s4.w is emitted as *(short2*)(&(i2.y)).y + // + ICHECK_EQ(t.lanes() % 2, 0) + << "only support even lane for shorT type with lanes > 4"; + os << "int" << t.lanes() / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 32: { + if (t.is_scalar()) { + os << "int"; + } else if (t.lanes() <= 4) { + os << "int" << t.lanes(); + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. + // + // int8 is stored as longlong4 + // + // i8.v1 is emitted as *(int2*)(&(l4.x)).x + // i8.v2 is emitted as *(int2*)(&(l4.x)).y + // + ICHECK_EQ(lanes % 2, 0) + << "only support even lane for int32 type with lanes > 4"; + os << "longlong" << lanes / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 64: { + if (t.is_scalar()) { + os << "int64_t"; + } else if (t.lanes() == 2) { + os << "longlong2"; + } else if (t.lanes() == 3) { + os << "longlong3"; + } else if (t.lanes() == 4) { + os << "longlong4"; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + default: + fail = true; + break; + } + if (!fail && lanes == 1) { + return; + } + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; +} + +void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, + PrimExpr lhs, PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + // Declare the result. + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + int ssa_scope = BeginScope(); + { + // Unpack into individual ops. + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); + } + } + EndScope(ssa_scope); + os << sret; +} + +void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t, + int i, + std::ostream &os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + + static const char access[] = {'x', 'y', 'z', 'w'}; + ICHECK(i >= 0 && i < 256 / t.bits()) + << "i: " << i << " t: " << t << " t.bits(): " << t.bits() + << " t.lanes(): " << t.lanes(); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + std::string type_name = t.is_int() ? "char" : "unsigned char"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else if (t.lanes() <= 16) { + std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; + } else { + ICHECK(t.lanes() == 32); + std::string ac = vec + "." + access[i / 8]; + os << "((" << type_name << ")(" << ac << " >> " << i % 8 * 8 << "))"; + } + } else if (t.is_float16()) { + if (t.lanes() <= 8) { + os << "((half2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" + << access[i % 2]; + } else { + os << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2]; + } + } else if (t.is_bfloat16()) { + if (t.lanes() <= 8) { + os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2]; + } else { + os << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2]; + } + } else if (t.is_float8()) { + os << vec; + // fp8_e5_32_t + if (t.lanes() >= 32) + os << "." << access[i / 16]; + // fp8_e5_16_t + if (t.lanes() >= 16) + os << "." << access[(i % 16) / 8]; + // fp8_e5_8_t + if (t.lanes() >= 8) + os << "." << access[(i % 8) / 4]; + // fp8_e5_4_t or fp8_e5_2_t + os << "." << access[i % 4]; + } else if (t.is_float4_e2m1fn()) { + os << vec; + // fp4_e2_64_t + if (t.lanes() >= 64) + os << "." << access[i / 32]; + // fp4_e2_32_t + if (t.lanes() >= 32) + os << "." << access[(i % 32) / 16]; + // fp4_e2_16_t + if (t.lanes() >= 16) + os << "." << access[(i % 16) / 8]; + // fp4_e2_8_t + if (t.lanes() >= 8) + os << "." << access[(i % 8) / 4]; + // fp4_e2_4_t or fp4_e2_2_t + os << "." << access[i % 4]; + } else if (t.lanes() > 4 && t.lanes() <= 8) { + std::string type_name; + if (t.bits() == 16) { + if (t.is_int()) { + type_name = "short"; + } else if (t.is_uint()) { + type_name = "ushort"; + } + } else if (t.bits() == 32) { + if (t.is_int()) { + type_name = "int"; + } else if (t.is_uint()) { + type_name = "uint"; + } else if (t.is_float()) { + type_name = "float"; + } + } + ICHECK(!type_name.empty()); + os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] + << ")))->" << access[i % 2]; + } else { + os << vec << "." << access[i]; + } +} + +void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t, + int i, const std::string &value) { + this->PrintIndent(); + static const char access[] = {'x', 'y', 'z', 'w'}; + ICHECK(i >= 0 && i < 256 / t.bits()); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (t.lanes() == 2 || t.lanes() == 3) { + stream << vec << '.' << access[i % t.lanes()] << "=" + << "(" << value << ");\n"; + } else if (t.lanes() <= 16) { + std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + stream << ac << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; + } + stream << "(" << value << " << " << i % 4 * 8 << ");\n"; + } else { + ICHECK(t.lanes() == 32); + std::string ac = vec + "." + access[i / 8]; + stream << ac << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << ac << " & ~(0x000000ff << " << i % 8 * 8 << ") |"; + } + stream << "(" << value << " << " << i % 8 * 8 << ");\n"; + } + } else if (t.is_float16()) { + if (t.lanes() <= 8) { + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2] << " = " << value << ";\n"; + } else { + stream << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2] << " = " << value + << ";\n"; + } + } else if (t.is_bfloat16()) { + if (t.lanes() <= 8) { + stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2] << " = " << value << ";\n"; + } else { + stream << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] + << "))) + " << (i / 2 % 2) << ")->" << access[i % 2] << " = " + << value << ";\n"; + } + } else if (t.is_float8()) { + stream << vec; + // fp8_e5_32_t + if (t.lanes() >= 32) + stream << "." 
<< access[i / 16]; + // fp8_e5_16_t + if (t.lanes() >= 16) + stream << "." << access[(i % 16) / 8]; + // fp8_e5_8_t + if (t.lanes() >= 8) + stream << "." << access[(i % 8) / 4]; + // fp8_e5_4_t or fp8_e5_2_t + stream << "." << access[i % 4] << " = " << value << ";\n"; + } else if (t.lanes() > 4 && t.lanes() <= 8) { + std::string type_name; + if (t.bits() == 16) { + if (t.is_int()) { + type_name = "short"; + } else if (t.is_uint()) { + type_name = "ushort"; + } + } else if (t.bits() == 32) { + if (t.is_int()) { + type_name = "int"; + } else if (t.is_uint()) { + type_name = "uint"; + } else if (t.is_float()) { + type_name = "float"; + } + } + ICHECK(!type_name.empty()); + stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] + << ")))->" << access[i % 2] << " = " << value << ";\n"; + } else if (t.is_float4_e2m1fn()) { + stream << vec; + // fp4_e2_64_t + if (t.lanes() >= 64) + stream << "." << access[i / 32]; + // fp4_e2_32_t + if (t.lanes() >= 32) + stream << "." << access[(i % 32) / 16]; + // fp4_e2_16_t + if (t.lanes() >= 16) + stream << "." << access[(i % 16) / 8]; + // fp4_e2_8_t + if (t.lanes() >= 8) + stream << "." << access[(i % 8) / 4]; + // fp4_e2_4_t or fp4_e2_2_t + stream << "." << access[i % 4] << " = " << value << ";\n"; + } else { + stream << vec << "." << access[i] << " = " << value << ";\n"; + } +} + +void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) { + auto args = op->args; + const std::string &sync = args[0].as()->value; + if (sync == "warp") { + // DO nothing. + } else if (sync == "shared" || sync == "shared.dyn") { + this->PrintIndent(); + if (args.size() == 1) { + this->stream << "__syncthreads();\n"; + } else if (args.size() == 2) { + auto barrier_id = args[1].as()->value; + this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n"; + } else if (args.size() == 3) { + auto barrier_id = args[1].as()->value; + auto thread_count = args[2].as()->value; + this->stream << "tl::__sync_thread_partial<" << barrier_id << ", " + << thread_count << ">();\n"; + } else { + LOG(FATAL) << "Invalid number of arguments for storage sync: " + << args.size(); + } + } else if (sync == "global") { + if (!need_global_barrier_) { + need_global_barrier_ = true; + } + // global synchronizer + std::string is_load = PrintExpr(op->args[1]); + std::string num_blocks = PrintExpr(op->args[2]); + this->PrintIndent(); + // In theory only threadfence is needed + // but we observed problems with only threadfence + this->stream << "__threadfence_system();\n"; + this->PrintIndent(); + this->stream << "if (" << is_load << ") {\n"; + int wb = this->BeginScope(); + this->PrintIndent(); + this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; + this->PrintIndent(); + std::string ptr = name_supply_->FreshName("pf"); + this->stream << "volatile unsigned* " << ptr << " = &" + << vid_global_barrier_state_ << ";\n"; + this->PrintIndent(); + this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; + this->PrintIndent(); + this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ + << ");\n"; + this->EndScope(wb); + this->PrintIndent(); + this->stream << "}\n"; + this->PrintIndent(); + this->stream << "__syncthreads();\n"; + } +} + +void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope, + std::ostream &os) { // NOLINT(*) + ICHECK_NE(scope, "global") + << "Cannot allocate global memory when targeting CUDA. 
You must pass " + "all global arrays as input instead"; + if (scope == "shared" || scope == "shared.barrier") { + os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ __align__(1024) "; + } +} + +std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, + DataType target) { + if (from == target) + return value; + std::ostringstream os; + os << "(("; + this->PrintType(target, os); + os << ")"; + if (from.is_float16() && (target.is_int() || target.is_uint()) && + target.bits() == 8) { + os << "("; + if (target.is_uint()) { + os << "u"; + } + os << "int)"; + } + if ((from.is_float16() || from.is_bfloat16()) && target.is_float8()) { + os << "(float)"; + } + os << value << ")"; + return os.str(); +} + +void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { + DataType from_ty = op->value.dtype(); + DataType target_ty = op->dtype; + ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + + // Emit simple C-style type conversion. + if (from_ty.is_scalar()) + return CodeGenC::VisitExpr_(op, os); + + // We could emit make_float4 like calls, but the emitted code looks + // too compact to read. Emit this as vectorized unary ops. + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(target_ty, stream); + stream << ' ' << sret << ";\n"; + std::string src = SSAGetID(PrintExpr(op->value), from_ty); + + // Handle conversion between float16 and float32 + if (from_ty.is_float16() && target_ty.is_float()) { + // Use __half22float2 for vectorized conversion (half2 -> float2) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // half2 -> float2 + PrintIndent(); + stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // half4 -> float4 + PrintIndent(); + stream << "((float2*)(&" << sret << "))[0] = " + << "__half22float2(*(half2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[1] = " + << "__half22float2(*((half2*)(&(" << src << "))+1));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // half8 -> float8 + PrintIndent(); + stream << "((float2*)(&" << sret << "))[0] = " + << "__half22float2(*(half2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[1] = " + << "__half22float2(*((half2*)(&(" << src << "))+1));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[2] = " + << "__half22float2(*((half2*)(&(" << src << "))+2));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[3] = " + << "__half22float2(*((half2*)(&(" << src << "))+3));\n"; + os << sret; + return; + } + } else if (from_ty.is_float() && target_ty.is_float16()) { + // Use __float22half2_rn for vectorized conversion (float2 -> half2) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // float2 -> half2 + PrintIndent(); + stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&(" + << src << ")));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // float4 -> half4 + PrintIndent(); + stream << "((half2*)(&" << sret << "))[0] = " + << "__float22half2_rn(*(float2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "((half2*)(&" << sret << "))[1] = " + << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // float8 -> half8 + PrintIndent(); + 
stream << "((half2*)(&" << sret << "))[0] = " + << "__float22half2_rn(*(float2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "((half2*)(&" << sret << "))[1] = " + << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n"; + PrintIndent(); + stream << "((half2*)(&" << sret << "))[2] = " + << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n"; + PrintIndent(); + stream << "((half2*)(&" << sret << "))[3] = " + << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n"; + os << sret; + return; + } + } + + // Handle conversion between bfloat16 and float32 + if (from_ty.is_bfloat16() && target_ty.is_float()) { + // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // bfloat162 -> float2 + PrintIndent(); + stream << sret + << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" + << src << ")));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // bfloat162x2 -> float4 + PrintIndent(); + stream << "((float2*)(&" << sret << "))[0] = " + << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" + << src << ")));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[1] = " + << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" + << src << "))+1));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // bfloat162x4 -> float8 + PrintIndent(); + stream << "((float2*)(&" << sret << "))[0] = " + << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" + << src << ")));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[1] = " + << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" + << src << "))+1));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[2] = " + << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" + << src << "))+2));\n"; + PrintIndent(); + stream << "((float2*)(&" << sret << "))[3] = " + << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" + << src << "))+3));\n"; + os << sret; + return; + } + } else if (from_ty.is_float() && target_ty.is_bfloat16()) { + // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // float2 -> bfloat162 + PrintIndent(); + stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret + << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // float4 -> bfloat162x2 + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = " + << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = " + << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // float8 -> bfloat162x4 + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = " + << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = " + << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n"; + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = " + << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n"; + PrintIndent(); + stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = 
" + << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n"; + os << sret; + return; + } + } + + // Handle conversion from float32 to float8 (E4M3/E5M2) + if (from_ty.is_float() && (target_ty.is_float8())) { + bool target_type_is_e4m3 = target_ty.is_float8_e4m3() || + target_ty.is_float8_e4m3fn() || + target_ty.is_float8_e4m3fnuz(); + // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion + // (float2 -> fp8x2) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // float2 -> fp8x2 + PrintIndent(); + stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret + << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast(&(" + << src << ")), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // float4 -> fp8x4 + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " + << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src + << ")), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " + << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src + << "))+1), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // float8 -> fp8x8 + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " + << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src + << ")), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " + << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src + << "))+1), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = " + << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src + << "))+2), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + PrintIndent(); + stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = " + << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src + << "))+3), __NV_SATFINITE, " + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; + os << sret; + return; + } + } + + if (from_ty.is_float8() && target_ty.is_float()) { + bool from_type_is_e4m3 = from_ty.is_float8_e4m3() || + from_ty.is_float8_e4m3fn() || + from_ty.is_float8_e4m3fnuz(); + // FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion + // (fp8x2 -> float2) + if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { + // fp8x2 -> float2 + PrintIndent(); + stream << "*reinterpret_cast(&(" << sret + << ")) = " + "__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_" + "t*>(&(" + << src << ")), " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + os << sret; + return; + } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { + // fp8x4 -> float4 + PrintIndent(); + stream << "*(float2*)(&" << sret << ") = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + PrintIndent(); + stream << "*((float2*)(&" << sret << ")+1) = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[1], " << (from_type_is_e4m3 ? 
"__NV_E4M3" : "__NV_E5M2") + << ");\n"; + os << sret; + return; + } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { + // fp8x8 -> float8 + PrintIndent(); + stream << "*(float2*)(&" << sret << ") = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + PrintIndent(); + stream << "*((float2*)(&" << sret << ")+1) = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + PrintIndent(); + stream << "*((float2*)(&" << sret << ")+2) = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[2], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + PrintIndent(); + stream << "*((float2*)(&" << sret << ")+3) = " + << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src + << "))[3], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") + << ");\n"; + os << sret; + return; + } + } + + // Fallback: elementwise cast + for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { + std::ostringstream val; + val << "("; + PrintType(target_ty.element_of(), val); + val << ")("; + PrintVecElemLoad(src, from_ty, i, val); + val << ")"; + PrintVecElemStore(sret, target_ty, i, val.str()); + } + + os << sret; +} + +void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) { + // TODO(wt): Consider vectorized reduction and impl for other dtypes + DataType t = op->dtype; + + // Standard min/max functions don't support bfloat16 or float16 + if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) { + os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) + << ")"; + return; + } + + // For float32 and float64 scalar, use standard min functions + if (t.is_float() && t.is_scalar()) { + if (t.bits() == 32 || t.bits() == 64) { + os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")"; + return; + } + } + + // For all other scalar types (int, uint), use default implementation + CodeGenC::VisitExpr_(op, os); +} + +void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) { + // TODO(wt): Consider vectorized reduction and impl for other dtypes + DataType t = op->dtype; + + // Standard min/max functions don't support bfloat16 or float16 + if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) { + os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) + << ")"; + return; + } + + // For float32 and float64 scalar, use standard max functions + if (t.is_float() && t.is_scalar()) { + if (t.bits() == 32 || t.bits() == 64) { + os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")"; + return; + } + } + + // For all other scalar types (int, uint), use default implementation + CodeGenC::VisitExpr_(op, os); +} + +void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, + const Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + if (ret_dtype.is_fixed_length_vector()) { + // + // Emit an unsupported vector call + // + // v = intrin_f((float4*)A[0], (float4*)B[0]) + // + // as + // + // float4 __ret; + // { + // float4 __arg0 = ((float4*)A)[0]; + // float4 __arg1 = ((float4*)B)[0]; + // __ret.x = intrin_f(__arg0.x, __arg1.x); + // __ret.y = intrin_f(__arg0.y, __arg1.y); + // __ret.z = intrin_f(__arg0.z, __arg1.z); + // __ret.w = intrin_f(__arg0.w, __arg1.w); + // } + // v = __ret; + // + // Declare the result 
vector. + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(ret_dtype, stream); + stream << ' ' << sret << ";\n"; + { + // Load arguments. + std::vector sargs; + size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + sargs.push_back(std::move(val)); + } + + // Emit a scalar call for each lane. + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j > 0) + scall << ", "; + PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + } + scall << ")"; + PrintVecElemStore(sret, ret_dtype, i, scall.str()); + } + } + os << sret; + } else { + CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, + os); + } +} + +// Print a reference expression to a buffer. +std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::ostringstream os; + std::string vid = GetVarID(buffer_var); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + // bool is_vol = IsVolatile(buffer_var); + // always false for tl cutlass backend. + bool is_vol = false; + + auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { + std::ostringstream ptr_os; + ptr_os << "("; + if (is_vol) { + ptr_os << "volatile "; + } + if (!scope.empty() && IsScopePartOfType()) { + PrintStorageScope(scope, ptr_os); + } + PrintType(pointed_to, ptr_os); + ptr_os << "*)"; + return ptr_os.str(); + }; + + DataType buffer_element_dtype = buffer->dtype; + + std::string buffer_str = vid; + if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { + std::stringstream temp; + temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; + buffer_str = temp.str(); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + if (scope == "local.var" || scope.find("local.descriptor") == 0) { + os << vid; + return os.str(); + } + std::string index_str = PrintExpr(index); + if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) { + // This is a special case, because CodegenCUDA::PrintType() + // returns "int" for bool and for 4-bit integers. In most cases, + // we divide by the number of lanes to determine the index. + // However, the backing type for scalar int4 and scalar bool is + // int32. Therefore, we need to divide by the ratio of their + // sizes in that case. + int div_factor = (t.lanes() == 1) ? 
(32 / t.bits()) : t.lanes(); + + os << "*(" + << "(" << ptr_cast(t) << vid << ")" + << " + " << index_str << " / " << div_factor << ")"; + } else if (t == buffer_element_dtype) { + os << buffer_str << "[" << index_str << "]"; + } else { + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; + } + + return os.str(); +} + +std::string CodeGenTileLangCUDA::GetVecLoad(DataType t, + const BufferNode *buffer, + PrimExpr base) { + const VarNode *buffer_var = buffer->data.get(); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + + if (scope != "global" || t.bits() * t.lanes() <= 128) { + return this->CodeGenC::GetVecLoad(t, buffer, base); + } + ICHECK_EQ(t.bits() * t.lanes(), 256) + << "Unsupported vector load size: " << t.bits() * t.lanes(); + auto buffer_ref = this->GetBufferRef(t, buffer, base); + std::ostringstream os; + os << "tl::ld_global_256(&(" << buffer_ref << "))"; + return os.str(); +} + +void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t, + PrimExpr base, + const std::string &value) { + const VarNode *buffer_var = buffer->data.get(); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + + if (scope != "global" || t.bits() * t.lanes() <= 128) { + this->CodeGenC::PrintVecStore(buffer, t, base, value); + return; + } + ICHECK_EQ(t.bits() * t.lanes(), 256) + << "Unsupported vector load size: " << t.bits() * t.lanes(); + auto buffer_ref = this->GetBufferRef(t, buffer, base); + this->PrintIndent(); + this->stream << "tl::st_global_256(&(" << buffer_ref << "), " << value + << ");\n"; +} + +/** + * @brief Emit CUDA/TensorLib-specific code for a call expression. + * + * This visitor handles CallNode intrinsics and builtins that require emitting + * CUDA/TL-specific code (inline PTX/ASM sequences, TensorLanguage runtime + * calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based + * stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The + * function writes the generated code to the provided output stream and falls + * back to the C codegen for unrecognized calls. + * + * The method recognizes and emits code for (non-exhaustive): cp.async and its + * commit/wait variants, tma_load/store and im2col variants, ptX + * ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy + * MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX + * asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret + * paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm + * and related external calls, and other TL runtime calls. + * + * Side effects: + * - Emits to `os` and the internal codegen output stream. + * - May set internal feature flags (e.g., need_cooperative_groups_, + * need_mma_h_, need_cast_smem_ptr_to_int_, enable_sparse_gemm_). + * - May open/close SSA scopes and mutate internal variable mappings. + * - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument + * patterns. + * + * @param op The call node to generate code for; the function inspects op->op + * and op->args to determine the appropriate emission. 
+ * @param os Output stream to receive expression-level output when the caller + * expects an expression result (some paths write directly to the + * member stream instead). + */ +void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { + auto print_extern_call_stmt = [&](std::string name, size_t start = 0, + size_t end = 0) { + // Cache context into a private ss, otherwise the let node may generate + // within the function call arguments. + std::ostringstream ss; + + for (size_t i = start; i < op->args.size() - end; i++) { + if (i > start) + ss << ", "; + ss << this->PrintExpr(op->args[i]); + } + + this->PrintIndent(); + this->stream << name << "("; + this->stream << ss.str(); + this->stream << ");\n"; + }; + auto print_mbarrier_obj = [&](PrimExpr barrier_id) { + std::ostringstream ss; + if (barrier_id.as()) { + // incase the barrier_id is an integer, we need to print the barrier_id as + // an integer + ss << mbarrier_name_ << "[" << barrier_id << "]"; + } else { + // otherwise may be a T.get_mbarrier() call or BufferLoad Node + // we need to print the barrier_id as a string + ss << this->PrintExpr(barrier_id); + } + return ss.str(); + }; + if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = this->PrintExpr(op->args[0]); + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + // use size of argument list to indicate whether or not to use predicated + // cp.async + if (op->args.size() == 5) { + this->PrintIndent(); + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" + << dst_offset << ", " << src << "+" << src_offset << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[5]); + this->PrintIndent(); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << "+" << dst_offset << ", " << src << "+" << src_offset + << ", " << condition << ");\n"; + } + } else if (op->op.same_as(builtin::ptx_commit_group())) { + print_extern_call_stmt("tl::cp_async_commit"); + } else if (op->op.same_as(builtin::ptx_wait_group())) { + int n = Downcast(op->args[0])->value; + std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; + print_extern_call_stmt(func_name, 1); + } else if (op->op.same_as(builtin::create_barriers())) { + this->PrintIndent(); + int barrier_count = Downcast(op->args[0])->value; + auto mbarrier_storage_name = mbarrier_name_ + "_mem"; + this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "[" + << barrier_count << "];\n"; + this->PrintIndent(); + this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<" + << mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n"; + } else if (op->op.same_as(tl::get_mbarrier())) { + ICHECK_EQ(op->args.size(), 1); + std::string barrier_id = this->PrintExpr(op->args[0]); + os << mbarrier_name_ + "[" + barrier_id + "]"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + if (op->args.size() == 1) { + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + this->stream << mbarrier_obj << ".arrive();\n"; + } else if (op->args.size() == 3) { + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto cta_id = this->PrintExpr(op->args[1]); + auto pred = this->PrintExpr(op->args[2]); + this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred + << ");\n"; + } else { + LOG(FATAL) << "Invalid parameter for 
tl::arrive_barrier " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { + ICHECK_EQ(op->args.size(), 2); + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto arrive_count = this->PrintExpr(op->args[1]); + this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + if (op->args.size() == 2) { + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = this->PrintExpr(op->args[1]); + this->stream << mbarrier_obj << ".arrive_and_expect_tx(" + << transaction_bytes << ");\n"; + } else if (op->args.size() == 4) { + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = this->PrintExpr(op->args[1]); + auto cta_id = this->PrintExpr(op->args[2]); + auto pred = this->PrintExpr(op->args[3]); + this->stream << mbarrier_obj << ".arrive_and_expect_tx(" + << transaction_bytes << ", " << cta_id << ", " << pred + << ");\n"; + } else { + LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + print_extern_call_stmt("tl::mbarrier_cp_async_arrive"); + } else if (op->op.same_as(tl::ptx_fence_barrier_init())) { + print_extern_call_stmt("tl::fence_barrier_init"); + } else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { + print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc"); + } else if (op->op.same_as(tl::mbarrier_expect_tx())) { + ICHECK_EQ(op->args.size(), 2); + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = this->PrintExpr(op->args[1]); + this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes + << ");\n"; + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + ICHECK_EQ(op->args.size(), 2); + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto phase = this->PrintExpr(op->args[1]); + this->stream << mbarrier_obj << ".wait(" << phase << ");\n"; + } else if (op->op.same_as(tl::ptx_init_tensor_memory())) { + print_extern_call_stmt("tl::tmem_allocate"); + } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) { + print_extern_call_stmt("tl::tmem_deallocate"); + } else if (op->op.same_as(tl::no_set_max_nreg())) { + return; + } else if (op->op.same_as(tl::tma_load())) { + std::ostringstream ss; + ICHECK_GE(op->args.size(), 2); + auto eviction_policy = + this->eviction_policy_names_ + [op->args[op->args.size() - 1].as()->value]; + // Simplify the code by using the default eviction policy + if (eviction_policy != "EVICT_NORMAL") { + ss << "tl::tma_load("; + } else { + ss << "tl::tma_load("; + } + auto desc = op->args[0]; + ss << this->PrintExpr(desc) << ", "; + ss << print_mbarrier_obj(op->args[1]) << ", "; + for (size_t i = 2; i < op->args.size() - 1; i++) { + if (i > 2) + ss << ", "; + ss << this->PrintExpr(op->args[i]); + } + ss << ");\n"; + this->PrintIndent(); + this->stream << ss.str(); + } else if (op->op.same_as(tl::tma_load_im2col())) { + std::stringstream ss; + auto eviction_policy = + this->eviction_policy_names_ + [op->args[op->args.size() - 1].as()->value]; + if (eviction_policy != "EVICT_NORMAL") { + ss << "tl::tma_load_im2col"; + } else { + ss << "tl::tma_load_im2col"; + } + print_extern_call_stmt(ss.str(), 0, 1); + } else if (op->op.same_as(tl::tma_store())) { + std::stringstream ss; + auto need_reduce = 
op->args[op->args.size() - 2].as()->value; + if (need_reduce) { + print_extern_call_stmt("tl::tma_store_add", 0, 2); + return; + } + auto eviction_policy = + this->eviction_policy_names_ + [op->args[op->args.size() - 1].as()->value]; + if (eviction_policy != "EVICT_NORMAL") { + ss << "tl::tma_store"; + } else { + ss << "tl::tma_store"; + } + print_extern_call_stmt(ss.str(), 0, 2); + } else if (op->op.same_as(tl::ptx_ldmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::ptx_stmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::fence_proxy_async())) { + print_extern_call_stmt("tl::fence_proxy_async"); + } else if (op->op.same_as(tl::tma_store_arrive())) { + print_extern_call_stmt("tl::tma_store_arrive"); + } else if (op->op.same_as(tl::tma_store_wait())) { + print_extern_call_stmt("tl::tma_store_wait<0>"); + } else if (op->op.same_as(tl::warpgroup_arrive())) { + print_extern_call_stmt("tl::warpgroup_arrive"); + } else if (op->op.same_as(tl::warpgroup_commit_batch())) { + print_extern_call_stmt("tl::warpgroup_commit_batch"); + } else if (op->op.same_as(tl::warpgroup_wait())) { + this->PrintIndent(); + int num_mma = Downcast(op->args[0])->value; + this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma) + << ">();\n"; + } else if (op->op.same_as(tl::warpgroup_fence_operand())) { + ICHECK_EQ(op->args.size(), 4U); + std::string dtype = Downcast(op->args[0])->value; + std::string data_ptr = this->PrintExpr(op->args[1]); + std::string offset = this->PrintExpr(op->args[2]); + std::string num_regs = this->PrintExpr(op->args[3]); + auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype); + std::string cast_type = "uint32_t"; + if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 || + dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) { + cast_type = "float"; + } + this->PrintIndent(); + this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type + << "*>(" << data_ptr << " + " << offset << "), " << num_regs + << ");\n"; + } else if (op->op.same_as(tl::set_max_nreg())) { + this->PrintIndent(); + int nreg = Downcast(op->args[0])->value; + int is_inc = Downcast(op->args[1])->value; + std::string func_name = + is_inc ? 
"tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; + this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; + } else if (op->op.same_as(tl::wait_wgmma())) { + this->PrintIndent(); + int num_mma = Downcast(op->args[0])->value; + this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; + } else if (op->op.same_as(tl::pack_b16())) { + os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " + << this->PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::sync_grid())) { + this->need_cooperative_groups_ = true; + this->PrintIndent(); + this->stream << "cooperative_groups::this_grid().sync();\n"; + } else if (op->op.same_as(tl::loop_break())) { + this->PrintIndent(); + this->stream << "break;\n"; + } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 6U); + os << "nvcuda::wmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[6], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::store_matrix_sync("; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[6], os); + if (const StringImmNode *str = op->args[7].as()) { + os << ", nvcuda::wmma::mem_" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->op.same_as(builtin::tvm_mma_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", " : ")"); + } + } else if (op->op.same_as(builtin::tvm_bmma_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::bmma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", " : ")"); + } + } else if (op->op.same_as(builtin::ptx_mma())) { + // arg 0: shape: mXnXkX + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: fp16, fp64, ... + // arg 4: B precision: fp16, fp64, ... + // arg 5: C precision: fp32, fp64, ... 
+ // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + // arg 12: saturate + // arg 13: (optional) 1-bit operator (xor or and) + ICHECK(op->args.size() == 13U || op->args.size() == 14U); + std::string shape = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_bias = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_bias = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_bias = this->PrintExpr(op->args[11]); + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_mma_instruction_h_ = true; + this->PrintIndent(); + std::string mma_call = + "tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), " + "(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast((A_ptr) + (A_offset)), " + "reinterpret_cast((B_ptr) + (B_offset)));\n"; + tl::codegen::Replacer replacer; + + // TODO(lei): Type Workaround for TF32, should be removed when + // we introduced tfloat32_t in the frontend. + std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + std::string ARegType = tl::codegen::GetMMARegisterType(dtype_a_enum); + if (ARegType == "float") { + ARegType = "uint32_t"; + } + std::string BRegType = tl::codegen::GetMMARegisterType(dtype_b_enum); + if (BRegType == "float") { + BRegType = "uint32_t"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true"); + replacer.register_rule("(TransB)", B_layout == "row" ? 
"false" : "true"); + replacer.register_rule("(ARegType)", ARegType); + replacer.register_rule("(BRegType)", BRegType); + replacer.register_rule("(CRegType)", + tl::codegen::GetMMARegisterType(dtype_c_enum)); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", a_bias); + replacer.register_rule("(B_ptr)", b_ref); + replacer.register_rule("(B_offset)", b_bias); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_bias); + this->stream << replacer.rewrite(mma_call); + } else if (op->op.same_as(tl::ptx_mma_sm70())) { + // arg 0: shape: mXnXkX + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: fp16 + // arg 4: B precision: fp16 + // arg 5: C precision: fp16, fp32 + // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + // arg 12: saturate + ICHECK_EQ(op->args.size(), 12U); + std::string shape = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_bias = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_bias = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_bias = this->PrintExpr(op->args[11]); + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_mma_sm70_instruction_h_ = true; + this->PrintIndent(); + std::string mma_call = + "tl::mma_sync_sm70<(AType), (BType), (CType), (M), (N), (K), (TransA), " + "(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast((A_ptr) + (A_offset)), " + "reinterpret_cast((B_ptr) + (B_offset)));\n"; + tl::codegen::Replacer replacer; + + replacer.register_rule("(AType)", + tl::codegen::ptx::DTypeEnumToString(dtype_a_enum)); + replacer.register_rule("(BType)", + tl::codegen::ptx::DTypeEnumToString(dtype_b_enum)); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true"); + replacer.register_rule("(TransB)", B_layout == "row" ? 
"false" : "true"); + replacer.register_rule("(ARegType)", + tl::codegen::GetMMARegisterType(dtype_a_enum)); + replacer.register_rule("(BRegType)", + tl::codegen::GetMMARegisterType(dtype_b_enum)); + replacer.register_rule("(CRegType)", + tl::codegen::GetMMARegisterType(dtype_c_enum)); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", a_bias); + replacer.register_rule("(B_ptr)", b_ref); + replacer.register_rule("(B_offset)", b_bias); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_bias); + this->stream << replacer.rewrite(mma_call); + } else if (op->op.same_as(builtin::ptx_mma_sp())) { + // arg 0: shape: mXnXkX + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: fp16, fp32, ... + // arg 4: B precision: fp16, fp32, ... + // arg 5: C precision: fp16, fp32, ... + // arg 6: A multiplicand pointer + // arg 7: A multiplicand index + // arg 8: B multiplicand pointer + // arg 9: B multiplicand index + // arg 10: C accumulator pointer + // arg 11: C accumulator index + // arg 12: metadata + // arg 13: metadata index + // arg 14: sparse_selector + // arg 15: saturate + ICHECK_EQ(op->args.size(), 16U); + std::string shape = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_offset = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_offset = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_offset = this->PrintExpr(op->args[11]); + std::string metadata = this->PrintExpr(op->args[12]); + std::string metadata_offset = this->PrintExpr(op->args[13]); + std::string sparse_selector = this->PrintExpr(op->args[14]); + bool saturate = Downcast(op->args[15])->value; + this->PrintIndent(); + std::string asm_code = PrintMMAAssembly( + shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, + b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, + sparse_selector, "", true, saturate); + this->stream << asm_code; + } else if (op->op.same_as(tl::ptx_wgmma_ss())) { + // arg 0: dtype + // arg 1: shape + // arg 2: A_layout + // arg 3: B_layout + // arg 4: A_dtype + // arg 5: B_dtype + // arg 6: C_dtype + // arg 7: multiplicand_a + // arg 8: multiplicand_b + // arg 9: accumulator + // arg 10: saturate + ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_ss args is " << op->args; + std::string shape = Downcast(op->args[0])->value; + bool a_is_k_major = Downcast(op->args[1])->value; + bool b_is_k_major = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_desc = this->PrintExpr(op->args[6]); + std::string A_offset = this->PrintExpr(op->args[7]); + std::string b_desc = this->PrintExpr(op->args[8]); + std::string B_offset = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_offset = this->PrintExpr(op->args[11]); + std::string scale_out = this->PrintExpr(op->args[12]); + bool scale_in_a = Downcast(op->args[13])->value; + bool scale_in_b = Downcast(op->args[14])->value; + + const bool a_is_shared = 
true; + this->PrintIndent(); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + need_wgmma_instruction_h_ = true; + std::string wgmma_asm_code = + "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), " + "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), " + "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n"; + // replace patterns + tl::codegen::Replacer replacer; + + std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(C_dtype)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true"); + replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true"); + replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1"); + replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref + " + " + c_offset); + replacer.register_rule("(scale_out)", scale_out); + wgmma_asm_code = replacer.rewrite(wgmma_asm_code); + this->stream << wgmma_asm_code; + } else if (op->op.same_as(tl::ptx_wgmma_rs())) { + // arg 0: shape + // arg 1: B_layout + // arg 2: A_dtype + // arg 3: B_dtype + // arg 4: C_dtype + // arg 5: multiplicand_a + // arg 6: multiplicand_a offset + // arg 7: multiplicand_b descriptor + // arg 8: multiplicand_b offset + // arg 9: accumulator + // arg 10: accumulator offset + // arg 11: scale_out + // arg 12: scale_in_a + // arg 13: scale_in_b + ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args; + std::string shape = Downcast(op->args[0])->value; + bool b_is_k_major = Downcast(op->args[1])->value; + std::string A_dtype = Downcast(op->args[2])->value; + std::string B_dtype = Downcast(op->args[3])->value; + std::string C_dtype = Downcast(op->args[4])->value; + std::string a_ref = this->PrintExpr(op->args[5]); + std::string A_offset = this->PrintExpr(op->args[6]); + std::string b_desc = this->PrintExpr(op->args[7]); + std::string B_offset = this->PrintExpr(op->args[8]); + std::string c_ref = this->PrintExpr(op->args[9]); + std::string c_offset = this->PrintExpr(op->args[10]); + std::string scale_out = this->PrintExpr(op->args[11]); + bool scale_in_a = Downcast(op->args[12])->value; + bool scale_in_b = Downcast(op->args[13])->value; + + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_wgmma_instruction_h_ = true; + this->PrintIndent(); + std::string wgmma_call = + "tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), " + "(tnspB), (scaleA), (scaleB)>(reinterpret_cast((A_ptr) + (A_offset)), " + "uint64_t((desc_b) + (B_offset)), " + "reinterpret_cast((C_ptr) + (C_offset)), " + "(scale_out));\n"; + + 
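+    // Note (descriptive comment, added): as in the ptx_mma and ptx_wgmma_ss
+    // branches above, the final call is assembled by textual substitution.
+    // Every "(Name)" token registered on the Replacer below is rewritten with
+    // its operand string, e.g. "(M)"/"(N)"/"(K)" become the parsed tile shape
+    // and "(A_ptr)"/"(A_offset)" become the printed operand expressions, so
+    // the rewritten tl::wgmma_rs<...>(...) statement is what gets written to
+    // this->stream.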
tl::codegen::Replacer replacer; + std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(tnspA)", "false"); + replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true"); + replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1"); + replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_offset); + replacer.register_rule("(scale_out)", scale_out); + wgmma_call = replacer.rewrite(wgmma_call); + this->stream << wgmma_call; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) { + ICHECK_EQ(op->args.size(), 14U) + << "ptx_tcgen05_mma_ss args is " << op->args; + std::string C_dtype = Downcast(op->args[0])->value; + std::string a_desc = this->PrintExpr(op->args[1]); + std::string A_offset = this->PrintExpr(op->args[2]); + std::string b_desc = this->PrintExpr(op->args[3]); + std::string B_offset = this->PrintExpr(op->args[4]); + std::string c_ref = this->PrintExpr(op->args[5]); + std::string c_offset = this->PrintExpr(op->args[6]); + PrimExpr desc_expr = op->args[7]; + std::string scale_out = this->PrintExpr(op->args[8]); + std::string mask0 = this->PrintExpr(op->args[9]); + std::string mask1 = this->PrintExpr(op->args[10]); + std::string mask2 = this->PrintExpr(op->args[11]); + std::string mask3 = this->PrintExpr(op->args[12]); + bool enable_ws = Downcast(op->args[13])->value; + + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + + need_tcgen05mma_instruction_h_ = true; + this->PrintIndent(); + std::string tcgen05_call = + "tl::(tcgen05_name)<(CType)>(uint64_t((desc_a) + (A_offset)), " + "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast((C))) " + "+ (C_offset), " + "(scale_out), static_cast((desc_val)), (mask0), (mask1), " + "(mask2), (mask3));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref); + replacer.register_rule("(C_offset)", c_offset); + replacer.register_rule("(tcgen05_name)", + enable_ws ? 
"tcgen05mma_ws_ss" : "tcgen05mma_ss"); + replacer.register_rule("(scale_out)", scale_out); + replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr)); + replacer.register_rule("(mask0)", mask0); + replacer.register_rule("(mask1)", mask1); + replacer.register_rule("(mask2)", mask2); + replacer.register_rule("(mask3)", mask3); + tcgen05_call = replacer.rewrite(tcgen05_call); + this->stream << tcgen05_call; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) { + // TS: A from TMEM, B from SMEM (desc) + ICHECK_EQ(op->args.size(), 13U) + << "ptx_tcgen05_mma_ts args is " << op->args; + std::string kind_dtype = Downcast(op->args[0])->value; + std::string a_ref = this->PrintExpr(op->args[1]); + std::string A_offset = this->PrintExpr(op->args[2]); + std::string b_desc = this->PrintExpr(op->args[3]); + std::string B_offset = this->PrintExpr(op->args[4]); + std::string c_ref = this->PrintExpr(op->args[5]); + std::string c_offset = this->PrintExpr(op->args[6]); + PrimExpr desc_expr = op->args[7]; + std::string scale_out = this->PrintExpr(op->args[8]); + std::string mask0 = this->PrintExpr(op->args[9]); + std::string mask1 = this->PrintExpr(op->args[10]); + std::string mask2 = this->PrintExpr(op->args[11]); + std::string mask3 = this->PrintExpr(op->args[12]); + + auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype); + + need_tcgen05mma_instruction_h_ = true; + this->PrintIndent(); + std::string tcgen05_call = + "tl::tcgen05mma_ts<(CType)>( (*reinterpret_cast((A))) + " + "(A_offset), " + "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast((C))) " + "+ (C_offset), " + "(scale_out), static_cast((desc_val)), (mask0), (mask1), " + "(mask2), (mask3));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_enum)); + replacer.register_rule("(A)", a_ref); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref); + replacer.register_rule("(C_offset)", c_offset); + replacer.register_rule("(scale_out)", scale_out); + replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr)); + replacer.register_rule("(mask0)", mask0); + replacer.register_rule("(mask1)", mask1); + replacer.register_rule("(mask2)", mask2); + replacer.register_rule("(mask3)", mask3); + tcgen05_call = replacer.rewrite(tcgen05_call); + this->stream << tcgen05_call; + } else if (op->op.same_as(tl::tcgen05_mma_arrive())) { + ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument"; + need_tcgen05_common_h_ = true; + this->PrintIndent(); + this->stream << "tl::tcgen05_mma_arrive(" << this->PrintExpr(op->args[0]) + << ");\n"; + } else if (op->op.same_as(builtin::ptx_ldmatrix())) { + // arg 0: whether the matrix is loaded in column major format or not. + // arg 1: number of matrices to load. + // arg 2: The data type in the matrix, .b16 is the only accepted data type. + // arg 3: pointer to local buffer. + // arg 4: The offset of the element to store in the local buffer. + // arg 5: pointer to the shared memory buffer to load. + // arg 6: The offset of the start element of the row to load in shared + // memory. 
+ ICHECK_EQ(op->args.size(), 7U); + bool trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string type = Downcast(op->args[2])->value; + std::string local_ptr = this->PrintExpr(op->args[3]); + std::string local_elem_offset = this->PrintExpr(op->args[4]); + std::string smem_ptr = this->PrintExpr(op->args[5]); + if (trans && op->dtype.bits() == 8) { + // Since ldmatrix assumes that a matrix element is 16 bit, it cannot + // properly transpose an int8 matrix. + std::string smem_stride = this->PrintExpr(op->args[6]); + ICHECK(num == 4); + os << "for (int i = 0; i < 16; ++i) {\n"; + os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr + << "[(i % 8) / 4 * " + smem_stride + + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + + "+ (i % 4) * " + smem_stride + + " + threadIdx.x / 4 + (i / 8) * 8];\n"; + os << "}\n"; + } else { + std::string smem_elem_offset = this->PrintExpr(op->args[6]); + std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + this->PrintIndent(); + this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset + << ", " << local_ptr << " + " << local_elem_offset << ");\n"; + } + } else if (op->op.same_as(builtin::mma_store())) { + int m = Downcast(op->args[0])->value; + int n = Downcast(op->args[1])->value; + std::string dst = this->PrintExpr(op->args[2]); + std::string src = this->PrintExpr(op->args[3]); + std::string src_offset = this->PrintExpr(op->args[4]); + PrimExpr stride = op->args[5]; + + ICHECK(m == 16 && n == 16) + << "Only m == 16 && n == 16 case supported for now"; + + // Each thread in a warp holds a certain number of elements of an MMA + // output. For example, if we compute a 16x16 tile using MMA, each thread + // holds 8 elements in its registers. So conceptually, a warp memory is + // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of + // memory is specified by the index map below. + + // To store the 32x8 output back to a 16x16 tile in shared or global memory, + // we invert this map to determine the output location for each 8 element. + + const auto index_map_func = ffi::Function::GetGlobal( + "tir.index_map.shared_16x16_to_mma_32x8_layout"); + + IndexMap index_map; + if (!index_map_func) { + Var i, j; + + // The index map is defined as follows: + index_map = IndexMap( + {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2), + 4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)}); + } else { + index_map = IndexMap::FromFunc(2, *index_map_func); + } + + arith::Analyzer analyzer; + auto inverse_index_map = + index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer); + auto indices_16x16 = inverse_index_map->final_indices; + + // "//" and "%" in the index map are translated to FloorDiv/Mod, but the + // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before + // they reach codegen, so manually replace them to the plain ones here. 
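+    // Note (descriptive comment, added): e.g. a FloorDiv(x, 8) in the inverse
+    // index map is emitted as x / 8 and a FloorMod(x, 2) as x % 2. This is
+    // safe here because the map is evaluated on threadIdx.x and local_id,
+    // which are non-negative, so floor division/modulo and C's truncating
+    // operators agree.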
+ class LowerFloorDivMod : public ExprMutator { + public: + PrimExpr VisitExpr_(const FloorDivNode *op) { + return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + PrimExpr VisitExpr_(const FloorModNode *op) { + return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + }; + + auto dst_ind = + LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); + + var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; + var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; + if (op->dtype.bits() == 16) { + os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n"; + os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])" + << " = " + << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n"; + os << "}\n"; + } else { + os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; + os << dst << "[" + this->PrintExpr(dst_ind) + "]" << " = " << src << "[" + << src_offset << " + local_id];\n"; + os << "}\n"; + } + + } else if (op->op.same_as(builtin::mma_fill())) { + std::string num_elem = this->PrintExpr(op->args[0]); + std::string dst = this->PrintExpr(op->args[1]); + std::string dst_offset = this->PrintExpr(op->args[2]); + + os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; + os << dst << "[" << dst_offset << " + i] = 0.0;"; + os << "}\n"; + } else if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = this->PrintExpr(op->args[0]); + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + need_cast_smem_ptr_to_int_ = true; + // use size of argument list to indicate whether or not to use predicated + // cp.async + if (op->args.size() == 5) { + this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, + size); + } else { + this->stream << PrintPredicatedCpAsyncAssembly( + dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); + } + } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { + need_cast_smem_ptr_to_int_ = true; + std::string dst = this->PrintExpr(op->args[0]); + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + int barrier_id = Downcast(op->args[5])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = + barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, + barrier); + } else if (op->op.same_as(builtin::ptx_commit_group())) { + this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; + } else if (op->op.same_as(builtin::ptx_wait_group())) { + int n = Downcast(op->args[0])->value; + this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n + << ";\");\n\n"; + } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { + need_cast_smem_ptr_to_int_ = true; + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = + barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + std::string thread_count = this->PrintExpr(op->args[1]); + this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + need_cast_smem_ptr_to_int_ = true; + int barrier_id = Downcast(op->args[0])->value; + 
CHECK(barrier_id < barrier_count_); + std::string barrier = + barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintArriveBarrierAsm(barrier); + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + need_cast_smem_ptr_to_int_ = true; + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = + barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + std::string byte_count = this->PrintExpr(op->args[1]); + this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count); + } else if (op->op.same_as(builtin::ptx_wait_barrier())) { + need_cast_smem_ptr_to_int_ = true; + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = + barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintWaitBarrierAsm(barrier); + } else if (op->op.same_as(builtin::ptx_ldg32())) { + /* + asm volatile ( + "{.reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + // " @p ld.global.nc.f32 %0, [%1];}\n"t + " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n" + : "=f"(reg) + : "l"(addr), "r"((int)guard) + ); + */ + + // get local + std::string reg = this->PrintExpr(op->args[0]); + // get guard + std::string guard = this->PrintExpr(op->args[1]); + const BufferLoadNode *addr_buffer = op->args[2].as(); + std::string global_addr = this->PrintExpr(addr_buffer->indices[0]); + std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data); + std::string local_addr = this->PrintExpr(op->args[3]); + this->stream << "asm volatile (\n"; + this->stream << "\"{.reg .pred p;\\n\"\n"; + this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n"; + this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; + this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n"; + // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; + stream << ": \"=f\"(" << reg << "[" << local_addr << "]" + << ")\n"; + stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr + << ")), \"r\"((int)" << guard << ")\n"; + stream << ");\n"; + } else if (op->op.same_as(tl::__ldg())) { + // Explicit read-only cached load. Preferred form: __ldg(BufferLoad(...)). + // Fallback form: __ldg(buffer, index) + const BufferLoadNode *bl = nullptr; + if (!op->args.empty()) { + bl = op->args[0].as(); + } + if (bl == nullptr) { + LOG(FATAL) << "T.__ldg expects a BufferLoad as the first argument."; + } + const BufferNode *buffer = bl->buffer.get(); + ICHECK_EQ(bl->indices.size(), 1) + << "T.__ldg currently supports flattened 1D buffer accesses."; + PrimExpr base = bl->indices[0]; + // Emit __ldg(&buffer_ref) + auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base); + os << "__ldg(&(" << buffer_ref << "))"; + } else if (op->op.same_as(builtin::reinterpret())) { + DataType tgt_dtype = op->dtype; + DataType src_dtype = op->args[0]->dtype; + PrimExpr value = op->args[0]; + + // Handle float4_e2m1fn reinterpret + if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) { + return CodeGenC::VisitExpr_(op, os); + } + if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() == + src_dtype.lanes() * src_dtype.bits()) { + return CodeGenC::VisitExpr_(op, os); + } + CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) + << "E2M1 float4 reinterpret expects source and target to have the same " + "number of lanes. 
" + << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) + << "E2M1 float4 reinterpret expects source and target to have the same " + "number of bytes. " + << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + + int lanes = tgt_dtype.lanes(); + + int ssa_scope = BeginScope(); + if (lanes == 1) { + // The case of lane=1 is same as the normal reinterpret, + // except that we allow the src and dst dtype to have different number of + // bits. + std::string rhs = SSAGetID(PrintExpr(value), src_dtype); + os << "(*("; + this->PrintType(tgt_dtype, os); + os << " *)(&(" << rhs << ")))"; + } else if (lanes == 2) { + if (tgt_dtype.is_float4_e2m1fn()) { + // We view the source as an uint16, and then extract bits of two fp4 + // numbers, and finally reinterpret the result as fp4x2. + value = + tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}); + tir::Var temp_var("temp_var", DataType::UInt(16)); + value = + tir::Let(temp_var, value, + tir::Cast(DataType::UInt(8), + (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var >> 4) & + IntImm(DataType::UInt(16), 0xF0)))); + } else { + value = tir::Cast( + DataType::UInt(16), + tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value})); + tir::Var temp_var("temp_var", DataType::UInt(16)); + value = + tir::Let(temp_var, value, + (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); + } + os << PrintExpr( + tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + } else if (lanes == 4) { + if (tgt_dtype.is_float4_e2m1fn()) { + // We view the source as an uint32, and then extract bits of four fp4 + // numbers, and finally reinterpret the result as fp4x4. + value = + tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}); + tir::Var temp_var("temp_var", DataType::UInt(32)); + value = tir::Let( + temp_var, value, + tir::Cast( + DataType::UInt(16), + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | + ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | + ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); + } else { + value = tir::Cast(DataType::UInt(32), + tir::Call(DataType::UInt(16), + tir::builtin::reinterpret(), {value})); + tir::Var temp_var("temp_var", DataType::UInt(32)); + value = tir::Let( + temp_var, value, + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | + ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | + ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); + } + os << PrintExpr( + tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + } else { + LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " + << lanes; + } + EndScope(ssa_scope); + } else if (op->op.same_as(builtin::thread_return())) { + os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_sp())) { + ICHECK(op->args.size() == 5) + << "tl_gemm_sp expects 5 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + enable_sparse_gemm_ = true; + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else 
if (op->op.same_as(tl::get_lane_idx())) { + ICHECK_LE(op->args.size(), 1) + << "tl.get_lane_idx expects at most one argument ."; + os << "tl::get_lane_idx("; + if (!op->args.empty()) { + os << PrintExpr(op->args[0]); + } + os << ")"; + } else if (op->op.same_as(tl::get_warp_idx_sync())) { + ICHECK_LE(op->args.size(), 1) + << "tl.get_warp_idx_sync expects at most one argument ."; + os << "tl::get_warp_idx_sync("; + if (!op->args.empty()) { + os << PrintExpr(op->args[0]); + } + os << ")"; + } else if (op->op.same_as(tl::get_warp_idx())) { + ICHECK_LE(op->args.size(), 1) + << "tl.get_warp_idx expects at most one argument ."; + os << "tl::get_warp_idx("; + if (!op->args.empty()) { + os << PrintExpr(op->args[0]); + } + os << ")"; + } else if (op->op.same_as(tl::get_warp_group_idx())) { + ICHECK_LE(op->args.size(), 2) + << "tl.get_warp_group_idx expects ."; + os << "tl::get_warp_group_idx("; + for (size_t i = 0; i < op->args.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << PrintExpr(op->args[i]); + } + os << ")"; + } else if (op->op.same_as(tl::tl_shuffle_elect())) { + os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; + } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { + ICHECK(op->args.size() == 5) + << "tl_initialize_wgmma_descriptor expects 5 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto start_address = op->args[1]; + auto layout_type = op->args[2]; + auto leading_byte_offset = op->args[3]; + auto stride_byte_offset = op->args[4]; + os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", " + << PrintExpr(leading_byte_offset) << ", " + << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", " + << PrintExpr(start_address) << ")"; + } else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) { + ICHECK(op->args.size() == 7) + << "tl_initialize_tcgen05_descriptor expects 7 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto start_address = op->args[1]; + auto leading_byte_offset = op->args[2]; + auto stride_byte_offset = op->args[3]; + auto base_offset = op->args[4]; + auto leading_abs = op->args[5]; + auto swizzle_mode = op->args[6]; + os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor) << ", " + << PrintExpr(start_address) << ", " << PrintExpr(leading_byte_offset) + << ", " << PrintExpr(stride_byte_offset) << ", " + << PrintExpr(base_offset) << ", " << PrintExpr(leading_abs) << ", " + << PrintExpr(swizzle_mode) << ")"; + } else if (op->op.same_as(tl::increase_descriptor_offset())) { + ICHECK(op->args.size() == 2) + << "tl_increase_descriptor_offset expects 2 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto offset = op->args[1]; + os << "tl::increase_descriptor_offset(" << PrintExpr(descriptor) + << ", " << PrintExpr(offset) << ")"; + } else if (op->op.same_as(tl::__exp())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "exp"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__exp10())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "exp10"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__log())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "log"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__log2())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "log2"); + os 
<< func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__log10())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "log10"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__tan())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "tan"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__cos())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "cos"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::__sin())) { + CUDAFastMath math_func; + std::string func_name = math_func(op->dtype, "sin"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ieee_add())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[2])->value; + std::string func_name = math_func(op->dtype, "fadd", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ieee_sub())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[2])->value; + std::string func_name = math_func(op->dtype, "fsub", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ieee_mul())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[2])->value; + std::string func_name = math_func(op->dtype, "fmul", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ieee_fmaf())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[3])->value; + std::string func_name = math_func(op->dtype, "fmaf", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")"; + } else if (op->op.same_as(tl::ieee_frcp())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[1])->value; + std::string func_name = math_func(op->dtype, "frcp", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ieee_fsqrt())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[1])->value; + std::string func_name = math_func(op->dtype, "fsqrt", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ieee_frsqrt())) { + CUDAIEEEMath math_func; + std::string func_name = math_func(op->dtype, "frsqrt", "rn"); + os << func_name << "(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ieee_fdiv())) { + CUDAIEEEMath math_func; + std::string rounding_mode = Downcast(op->args[2])->value; + std::string func_name = math_func(op->dtype, "fdiv", rounding_mode); + os << func_name << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::rng_init())) { + this->need_curand_kernel_h_ = true; + this->curand_philox_state = name_supply_->FreshName("__philox_state"); + this->PrintIndent(); + this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state + << ";\n"; + this->PrintIndent(); + this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) + << ", &" << this->curand_philox_state << ");\n"; + // Store state_var for 
later use by rng_rand + } else if (op->op.same_as(tl::rng_rand())) { + this->need_curand_kernel_h_ = true; + os << "curand(&" << this->curand_philox_state << ")"; + } else if (op->op.same_as(tl::warp_reduce_sum())) { + os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_max())) { + os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_min())) { + os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_bitand())) { + os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_bitor())) { + os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::fragment_shape) { + const VarNode *buffer = op->node.as(); + const StringImmNode *shape_str = op->value.as(); + fragment_shapes[buffer] = shape_str->value; + } else if (op->attr_key == tir::attr::fragment_layout) { + const VarNode *buffer = op->node.as(); + const StringImmNode *layout_str = op->value.as(); + fragment_layouts[buffer] = layout_str->value; + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode *queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + this->VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + this->VisitExpr(commit_group, this->stream); + return; + } else if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = + Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + this->VisitExpr(wait_group, this->stream); + auto inner = op->body.as(); + ICHECK(inner); + this->VisitStmt(inner->body); + return; + } else if (op->attr_key == "threadblock_swizzle_pattern") { + this->PrintIndent(); + const StringImmNode *pattern = op->value.as(); + ICHECK(pattern); + this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; + this->VisitStmt(op->body); + return; + } else if (op->attr_key == "pragma_unroll_factor") { + const IntImmNode *factor = op->value.as(); + ICHECK(factor); + unroll_factor[op->node.as()] = Downcast(factor); + } + + CodeGenC::VisitStmt_(op); +} + +void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + this->PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + const VarNode *buffer = op->buffer_var.as(); + if (scope.find("wmma.") == 0) { + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + ICHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) || + op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; + } else { + ICHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32)) + << "Accumulator only support half, float 
and int type for now"; + } + PrintWmmaScope(scope, op->dtype, buffer, stream); + } else if (scope == "local.descriptor.wgmma") { + stream << "tl::GmmaDescriptor " << vid << ";\n"; + } else if (scope == "local.descriptor.tcgen05_smem") { + stream << "tl::Tcgen05SMemDescriptor " << vid << ";\n"; + } else if (scope == "local.descriptor.tcgen05_instr") { + stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n"; + } else { + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + } + + if (scope == "shared.dyn") { + stream << ' ' << vid << "[];\n"; + } else { + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now, but get " + << constant_size << " for " << op->buffer_var->name_hint; + if (scope.find("wmma.") == 0) { + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + } + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + if (scope == "shared") { + stream << ' ' << vid << '[' << constant_size << "];\n"; + } else if (scope == "shared.barrier") { + auto v_id_mem = vid + "_mem"; + stream << ' ' << v_id_mem << "[" << constant_size << "];\n"; + PrintIndent(); + stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_ + << "*>(" << v_id_mem << ");\n"; + } else if (scope == "local") { + stream << ' ' << vid << '[' << constant_size << "];\n"; + } else if (scope == "local.var") { + PrimExpr init = tir::make_const(op->dtype, 0); + auto init_it = op->annotations.find(tl::attr::kLocalVarInit); + if (init_it != op->annotations.end()) { + PrimExpr user_init = Downcast((*init_it).second); + if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) { + user_init = tir::Cast(op->dtype, user_init); + } + init = user_init; + } + stream << ' ' << vid << " = " << PrintExpr(init) << ";\n"; + } else if (scope.find("local.descriptor") != 0) { + ICHECK(false) << "Unsupported scope: " << scope; + } + } + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); +} + +void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { + if (is_const_int(op->value)) + return; + const CallNode *call = op->value.as(); + if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { + PrintIndent(); + stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; + PrintIndent(); + stream << "if (threadIdx.x == 0) {\n"; + PrintIndent(); + stream << " " << vid_global_barrier_expect_ << " = 0;\n"; + PrintIndent(); + stream << "}\n"; + } + if (call && (call->op.same_as(tvm::tl::device_assert()))) { + std::string cond = PrintExpr(call->args[0]); + this->PrintIndent(); + stream << "device_assert(" << cond << ");\n"; + } else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) { + std::string cond = PrintExpr(call->args[0]); + std::string msg_expr = PrintExpr(call->args[1]); + this->PrintIndent(); + stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n"; + } else { + CodeGenC::VisitStmt_(op); + } +} + +void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { + int lanes = static_cast(Downcast(op->lanes)->value); + CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef(op) + << " with " << lanes << " lanes is not allowed."; + os << "(make_"; + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < lanes; i++) { + os << "(" << 
PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != lanes - 1) + os << ", "; + } + os << "))"; +} + +void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + int lanes = op->dtype.lanes(); + // declare type. + if (value_dtype.lanes() == element_dtype.lanes()) { + std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); + HandleVolatileLoads(ref, op, os); + } else { + bool can_vector_load = false; + arith::PVar base; + // For sub-byte types with lanes > 1 in element_dtype, adjust the ramp + // pattern + int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8) + ? value_dtype.lanes() / element_dtype.lanes() + : value_dtype.lanes(); + if (arith::ramp(base, 1, ramp_lanes).Match(index)) { + const RampNode *ramp = index.as(); + ICHECK(ramp); + can_vector_load = true; + // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); + // The condition: {k * coeff + base} divisible by the alignment for any k + // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() + // == 0) { + // can_vector_load = true; + // } + } + + if (can_vector_load) { + std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); + HandleVolatileLoads(ref, op, os); + } else { + std::ostringstream svalue_expr; + std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); + std::string vid = GetVarID(buffer_var.get()); + DataType elem_type = op->dtype.element_of(); + for (int i = 0; i < lanes; ++i) { + std::ostringstream value_temp; + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { + value_temp << "(("; + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, value_temp); + } + } + PrintType(elem_type, value_temp); + value_temp << "*)" << vid << ')'; + } else { + value_temp << vid; + } + value_temp << '['; + PrintVecElemLoad(sindex, index.dtype(), i, value_temp); + value_temp << ']'; + PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); + } + os << svalue_expr.str(); + } + } +} + +void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + + if (value_dtype.lanes() == element_dtype.lanes()) { + std::string value = this->PrintExpr(op->value); + std::string ref = + this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); + this->PrintIndent(); + stream << ref << " = " << value << ";\n"; + } else { + arith::PVar base; + // For sub-byte types with lanes > 1 in element_dtype, adjust the ramp + // pattern + int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8) + ? 
value_dtype.lanes() / element_dtype.lanes() + : value_dtype.lanes(); + + if (arith::ramp(base, 1, ramp_lanes).Match(index_expr)) { + std::string value = this->PrintExpr(op->value); + this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value); + } else { + // The assignment below introduces side-effect, and the resulting value + // cannot be reused across multiple expression, thus a new scope is needed + int vec_scope = BeginScope(); + + // store elements separately + std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype()); + std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); + std::string vid = GetVarID(buffer_var.get()); + for (int i = 0; i < value_dtype.lanes(); ++i) { + this->PrintIndent(); + DataType elem_type = value_dtype.element_of(); + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { + stream << "(("; + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } + } + PrintType(elem_type, stream); + stream << "*)" << vid << ')'; + } else { + stream << vid; + } + stream << '['; + PrintVecElemLoad(index, index_expr.dtype(), i, stream); + stream << "] = "; + PrintVecElemLoad(value, op->value.dtype(), i, stream); + stream << ";\n"; + } + EndScope(vec_scope); + } + } +} + +void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + int lanes = static_cast(Downcast(op->lanes)->value); + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) { + if (lanes == 4) { + // make_int8x4 + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + return; + } else if (lanes == 32) { + // make_int8x32 + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } else { + os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } + return; + } + } + + if (op->dtype.is_float16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + if (lanes <= 8) { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + } else { + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << "tl::pack_float16x4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } + } + os << ')'; + return; + } + + if (op->dtype.is_bfloat16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + if (lanes <= 8) { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + } else { + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << "tl::pack_bfloat16x4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } + } + os << ')'; + return; + } + + if (op->dtype.is_float() && op->dtype.bits() == 32 && + op->dtype.lanes() == 8) { + std::string v = PrintExpr(op->value); + os << "make_ulonglong4("; + for (int i = 0; i < 4; ++i) { + if (i != 0) + os << ", "; + os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if 
((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + bool fail = false; + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xF; + + if (lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; + if (op->dtype.is_uint()) { + os << "(uint16_t)" << v; + } else { + os << "(int16_t)" << v; + } + } else { + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | + (v << 4) | v; + if (lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (lanes == 16 || lanes == 32) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 8; ++i) { + if (i != 0) + os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + + if (!fail) { + return; + } + } + + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << ')'; +} + +inline void PrintConst(const FloatImmNode *op, std::ostream &os, + CodeGenTileLangCUDA *p) { // NOLINT(*) + // Type code is kBFloat/kFloat16 + // which is indeed CUTLASS supported types currently + if (op->dtype.is_bfloat16() || op->dtype.is_float16()) { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << "std::numeric_limits<"; + p->PrintType(op->dtype, temp); + temp << ">::infinity()"; + } else if (std::isnan(op->value)) { + temp << "std::numeric_limits<"; + p->PrintType(op->dtype, temp); + temp << ">::quiet_NaN()"; + } else { + p->PrintType(op->dtype, temp); + temp << '(' << std::hexfloat << op->value << 'f'; + temp << "/*" << std::scientific << op->value << "*/"; + temp << ')'; + } + p->MarkConst(temp.str()); + os << temp.str(); + return; + } + // Type code is kFloat8_e5m2 or kE4M4Float + if (op->dtype.is_float8() || op->dtype.is_float4()) { + p->PrintType(op->dtype, os); + os << '(' << std::hexfloat << op->value << 'f'; + os << "/*" << std::scientific << op->value << "*/"; + os << ')'; + return; + } + // Type code is kFloat64/kFloat32 (kFloat16 is handled above) + switch (op->dtype.bits()) { + case 64: + case 32: { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << ((op->dtype.bits() == 32) ? 
"CUDART_NAN_F" : "CUDART_NAN"); + p->need_math_constants_h_ = true; + } else { + temp << std::hexfloat << op->value; + if (op->dtype.bits() == 32) + temp << 'f'; + temp << "/*" << std::scientific << op->value << "*/"; + } + p->MarkConst(temp.str()); + os << temp.str(); + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + PrintConst(op, os, this); +} + +void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t, + const VarNode *variable, + std::ostream &os) { + std::stringstream type; + PrintType(t, type); + ICHECK(fragment_shapes.count(variable)) + << "Cannot find shape of the wmma fragment " << variable->name_hint; + std::string shape_str = fragment_shapes.at(variable); + if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { + type.str(std::string()); + if (t.is_int()) { + if (t.bits() == 4) { + type << "nvcuda::wmma::experimental::precision::s4"; + } else if (t.bits() == 1) { + type << "nvcuda::wmma::experimental::precision::b1"; + } else { + LOG(FATAL) << "Unhandled integer type for wmma fragment!"; + } + } else if (t.is_uint()) { + if (t.bits() == 4) { + type << "nvcuda::wmma::experimental::precision::u4"; + } else { + LOG(FATAL) << "Unhandled integer type for wmma fragment!"; + } + } + } + if (scope == "wmma.matrix_a") { + std::string layout_str = fragment_layouts[variable]; + ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.matrix_b") { + std::string layout_str = fragment_layouts[variable]; + ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.accumulator") { + os << "nvcuda::wmma::fragment"; + } +} + +int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope, + const VarNode *variable, + int32_t size) { + ICHECK(fragment_shapes.count(variable)) + << "Cannot find shape of the wmma fragment " << variable->name_hint; + std::string shape_str = fragment_shapes.at(variable); + std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); + if (dim.first * dim.second != 0) + return size / dim.first / dim.second; + else + return 0; +} + +void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value, + const BufferLoadNode *op, + std::ostream &os) { + // Cast away volatile qualifier for fp16 types. That is, only loads and + // stores are volatile. The loaded objects are not marked as volatile. 
+ // + if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && + IsVolatile(op->buffer->data.get())) { + os << "("; + PrintType(op->dtype, os); + os << ")(" << value << ")"; + } else { + os << value; + } +} + +void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, + const std::string &value, + std::ostream &os) { + ICHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (!(t.lanes() == 2 || t.lanes() == 3)) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 + << "))"; + return; + } + } + + if (t.is_float16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (t.is_bfloat16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_bfloat162(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << "("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + return; +} + +void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name, + const PrimFunc &func, + std::ostream &os) { + PrintFuncPrefix(os); + CodeGenC::PrintType(func->ret_type, os); + CodeGenC::PrintExtraAttrs(func, os); + bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias); + std::unordered_set non_restrict; + if (auto opt = + func->GetAttr>(tl::attr::kNonRestrictParams)) { + for (const tir::Var &v : opt.value()) + non_restrict.insert(v.get()); + } + // Read-only param indices attribute, if present. + std::unordered_set ro_param_indices; + if (auto opt = + func->GetAttr>("tl.readonly_param_indices")) { + for (const auto &idx : opt.value()) { + ro_param_indices.insert(static_cast(Downcast(idx)->value)); + } + } + os << " " << function_name << "("; + for (size_t i = 0; i < func->params.size(); ++i) { + tir::Var v = func->params[i]; + std::string vid = AllocVarID(v.get()); + + if (i > 0) { + os << ", "; + } + + if (v.dtype().is_handle()) { + // work around for grid constant parameters. + if (auto *ptr = v->type_annotation.as()) { + if (ptr->storage_scope == "grid_constant") { + os << "__grid_constant__ const "; + CodeGenC::PrintType(ptr->element_type, os); + os << ' ' << vid; + continue; + } + } + + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, os); + } + // If marked read-only, emit const qualifier before type. + if (ro_param_indices.count(static_cast(i))) { + os << "const "; + } + CodeGenC::PrintType(GetType(v), os); + if (auto *ptr = v->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } + + if (no_alias && !non_restrict.count(v.get())) { + PrintRestrict(v, os); + } + } else { + CodeGenC::PrintType(GetType(v), os); + } + os << ' ' << vid; + } + os << ")"; + + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). 
+ for (const auto ¶m : func->params) { + if (auto *ptr = param->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(param.get(), prim->dtype); + } + } + } +} + +void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar, + const PrimFunc &f) { + // If the function has already been forward-declared, this is a + // no-op. + CodeGenC::DeclareFunction(gvar, f); + // clear previous generated state. + this->InitFuncState(f); + // reserve keywords + ReserveKeywordsAsUnique(); + + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + std::unordered_set non_restrict; + if (auto opt = + f->GetAttr>(tl::attr::kNonRestrictParams)) { + for (const tir::Var &v : opt.value()) + non_restrict.insert(v.get()); + } + // Read-only param indices attribute, if present. + std::unordered_set ro_param_indices; + if (auto opt = f->GetAttr>("tl.readonly_param_indices")) { + for (const auto &idx : opt.value()) { + ro_param_indices.insert(static_cast(Downcast(idx)->value)); + } + } + + this->PrintFuncPrefix(stream); + CodeGenC::PrintType(f->ret_type, stream); + this->PrintExtraAttrs(f); + + this->stream << " " << static_cast(global_symbol.value()) << "("; + + for (size_t i = 0; i < f->params.size(); ++i) { + tir::Var v = f->params[i]; + std::string vid = AllocVarID(v.get()); + if (i != 0) + stream << ", "; + if (v.dtype().is_handle()) { + // work around for grid constant parameters. + if (auto *ptr = v->type_annotation.as()) { + if (ptr->storage_scope == "grid_constant") { + stream << "__grid_constant__ const "; + CodeGenC::PrintType(ptr->element_type, stream); + stream << ' ' << vid; + continue; + } + } + + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } + // If marked read-only, emit const qualifier before type. + if (ro_param_indices.count(static_cast(i))) { + stream << "const "; + } + CodeGenC::PrintType(GetType(v), stream); + if (auto *ptr = v->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } + + if (no_alias && !non_restrict.count(v.get())) { + PrintRestrict(v, stream); + } + } else { + CodeGenC::PrintType(GetType(v), stream); + } + stream << ' ' << vid; + } + stream << ") {\n"; + this->PreFunctionBody(f); + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_cuda.h b/tilelang/original/src/target/codegen_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf4602133e3ab658d263110d269878b1db18eda --- /dev/null +++ b/tilelang/original/src/target/codegen_cuda.h @@ -0,0 +1,165 @@ +/*! 
+ * \file target/codegen.h + * \brief Utility to generate code + */ +#ifndef TVM_TL_TARGET_CODEGEN_CUDA_H_ +#define TVM_TL_TARGET_CODEGEN_CUDA_H_ + +#include +#include +#include + +#include +#include + +#include "target/source/codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangCUDA final : public CodeGenC { +public: + CodeGenTileLangCUDA(); + std::string Finish(); + // override behavior + void PrintFuncPrefix(std::ostream &os) final; + void PrintExtraAttrs(const PrimFunc &f); + void VisitStmt_(const ForNode *op) final; + void PrintStorageSync(const CallNode *op) final; + void PrintStorageScope(const std::string &scope, + std::ostream &os) final; // NOLINT(*) + void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string &vec, DataType t, int i, + std::ostream &os) final; // NOLINT(*) + void PrintVecElemStore(const std::string &vec, DataType t, int i, + const std::string &value) final; + std::string GetVecLoad(DataType t, const BufferNode *buffer, + PrimExpr base) final; + void PrintVecStore(const BufferNode *buffer, DataType t, PrimExpr base, + const std::string &value) final; + void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) + void PrintVecElemLoadExpr(DataType t, int i, const std::string &value, + std::ostream &os) final; + std::string CastFromTo(std::string value, DataType from, + DataType target) final; + // overload visitor + void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; + void VisitExpr_(const CallNode *op, std::ostream &os) final; + void VisitExpr_(const CastNode *op, std::ostream &os) final; + void VisitExpr_(const MinNode *op, std::ostream &os) final; + void VisitExpr_(const MaxNode *op, std::ostream &os) final; + void VisitStmt_(const EvaluateNode *op) final; + void VisitStmt_(const AllocateNode *op) final; + void VisitStmt_(const AttrStmtNode *op) final; + void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final; + void VisitStmt_(const BufferStoreNode *op) final; + + // Override this as a work around for __grid_constant__ parameter + void AddFunction(const GlobalVar &gvar, const PrimFunc &f); + void PrintFunctionSignature(const ffi::String &function_name, + const PrimFunc &func, std::ostream &os); + +protected: + virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, + PrimExpr index) final; + void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, + std::ostream &os) final; // NOLINT(*) + +private: + // Handle volatile loads + void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op, + std::ostream &os) final; + + // Whether scope such as "__shared__" or "__constant__" is part of type. + bool IsScopePartOfType() const final { return false; } + + friend void PrintConst(const FloatImmNode *op, std::ostream &os, + CodeGenTileLangCUDA *p); + + // Whether global barrier is needed. + bool need_global_barrier_{false}; + // Global barrier state + std::string vid_global_barrier_state_; + // Global barrier expected node. 
+ std::string vid_global_barrier_expect_; + // Global curand state + std::string curand_philox_state; + + // whether enable fp16 + bool enable_fp16_{false}; + // whether enable bf16 + bool enable_bf16_{false}; + // whether enable fp8 + bool enable_fp8_{false}; + // whether enable fp6 + bool enable_fp6_{false}; + // whether enable fp4 + bool enable_fp4_{false}; + // whether enable int8 + bool enable_int8_{false}; + // whether enable sparse gemm + bool enable_sparse_gemm_{false}; + // whether enable warp shuffle intrinsics + bool enable_warp_shuffle_{false}; + // whether need math_constants.h + bool need_math_constants_h_{false}; + // whether need mma.h + bool need_mma_h_{false}; + // whether need tl mma instruction header + bool need_mma_instruction_h_{false}; + // whether need tl wgmma instruction header + bool need_wgmma_instruction_h_{false}; + // whether need tl tcgen05mma instruction header + bool need_tcgen05mma_instruction_h_{false}; + // whether need tl mma_sm70 instruction header + bool need_mma_sm70_instruction_h_{false}; + // whether need tcgen_05 common header + bool need_tcgen05_common_h_{false}; + // whether need cast_smem_ptr_to_int helper function + bool need_cast_smem_ptr_to_int_{false}; + // whether need cooperative_groups.h + bool need_cooperative_groups_{false}; + // whether need curand_kernel.h + bool need_curand_kernel_h_{false}; + // Op attribute map + OpAttrMap op_need_warp_shuffle_ = + Op::GetAttrMap("cuda.need_warp_shuffle"); + + // The name of the barrier array in shared memory + const std::string barrier_name_ = "barrier"; + // The size of the barrier array in shared memory + int barrier_count_ = -1; + // The name of the mbarrier array in shared memory + const std::string mbarrier_name_ = "mbarrier"; + // The type name of the mbarrier array + const std::string mbarrier_dtype_ = "Barrier"; + // The alignment of the barrier array in shared memory + // Set to 16 to maintain minimum alignment requirements for async bulk copy + const int barrier_alignment_bytes_ = 16; + + std::unordered_map fragment_shapes; + std::unordered_map fragment_layouts; + std::unordered_map unroll_factor; + friend void PrintConst(const FloatImmNode *op, std::ostream &os, + CodeGenTileLangCUDA *p); + void PrintWmmaScope(const std::string &scope, DataType t, + const VarNode *variable, std::ostream &os); + int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable, + int32_t size); + + std::vector eviction_policy_names_ = { + "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"}; + std::unordered_set bf16_supported_ops_ = { + "bf1622float2", "bf1622int16", "float22bf162", "bf162bf162"}; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_ diff --git a/tilelang/original/src/target/codegen_cutedsl.cc b/tilelang/original/src/target/codegen_cutedsl.cc new file mode 100644 index 0000000000000000000000000000000000000000..8279710de4bdf7b987cddca985b1869e8e18b5ae --- /dev/null +++ b/tilelang/original/src/target/codegen_cutedsl.cc @@ -0,0 +1,1355 @@ +/*! 
+ * \file target/codegen_cutedsl.cc + */ + +#include "codegen_cutedsl.h" +#include "codegen_utils.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "arith/pattern_match.h" + +namespace tvm { +namespace codegen { +namespace { + +// The threshold of the loop extent to use cutlass.range_constexpr +// Higher values would lead to DSLOptimizationWarning: +// This static loop has 128 iterations, which may be very slow to compile, +// consider using `cutlass.range(..., unroll_full=True)` instead. +const int64_t LOOP_UNROLL_THRESHOLD = 64; + +void ReplaceAll(std::string &str, const std::string &from, + const std::string &to) { + ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty"; + auto pos = str.find(from); + while (pos != std::string::npos) { + str.replace(pos, from.size(), to); + pos = str.find(from, pos + to.size()); + } +} + +} // namespace + +CodeGenTileLangCuTeDSL::CodeGenTileLangCuTeDSL() { + // Read fastmath configuration from current PassContext + auto pass_ctx = tvm::transform::PassContext::Current(); + + // Read tl.enable_fast_math config, default to false + enable_fastmath_ = + pass_ctx->GetConfig(tl::kEnableFastMath, Bool(false)).value(); +} + +std::string CodeGenTileLangCuTeDSL::CanonicalizeFastmathFunctionName_( + const std::string &func_name) const { + static const std::unordered_map kFastMathMap = { + {"divf", "tl.divf"}, {"exp", "tl.exp"}, {"expf", "tl.exp"}, + {"exp2", "tl.exp2"}, {"exp2f", "tl.exp2"}, {"log", "tl.log"}, + {"logf", "tl.log"}, {"log2", "tl.log2"}, {"log2f", "tl.log2"}, + {"log10", "tl.log10"}, {"tan", "tl.tan"}, {"cos", "tl.cos"}, + {"sin", "tl.sin"}, {"sqrt", "tl.sqrt"}, {"sqrtf", "tl.sqrt"}, + }; + + auto it = kFastMathMap.find(func_name); + if (it != kFastMathMap.end()) { + return it->second; + } + return ""; +} + +void CodeGenTileLangCuTeDSL::PrintFuncDecorator_( + std::ostream &os) { // NOLINT(*) + os << "@cute.kernel\n"; +} + +void CodeGenTileLangCuTeDSL::PreFunctionBody_(const PrimFunc &f) { + PrintIndent(); + stream << "threadIdx = tl.ThreadIdx()" << "\n"; + PrintIndent(); + stream << "blockIdx = tl.BlockIdx()" << "\n"; +} + +namespace { +std::string DTypeToString(DataType t) { + ICHECK(t.is_scalar()) << "unsupported type " << t; + + if (t.is_void()) { + return "void"; + } + if (t == tl::cuTensorMapType()) { + return "CUtensorMap"; + } + + int bits = t.bits(); + std::string elem_type; + if (t.is_float()) { + if (bits == 16 || bits == 32 || bits == 64) { + elem_type = "Float" + std::to_string(bits); + } + } else if (t.is_bfloat16()) { + elem_type = "BFloat16"; + } else if (t.is_float8()) { + if (t.is_float8_e3m4()) { + // unsupported + } else if (t.is_float8_e4m3()) { + elem_type = + "Float8E4M3FN"; // Only Float8E4M3FN is supported at the moment + } else if (t.is_float8_e4m3b11fnuz()) { + // unsupported + } else if (t.is_float8_e4m3fn()) { + elem_type = "Float8E4M3FN"; + } else if (t.is_float8_e4m3fnuz()) { + // unsupported + } else if (t.is_float8_e5m2()) { + elem_type = "Float8E5M2"; + } else if (t.is_float8_e5m2fnuz()) { + // unsupported + } else if (t.is_float8_e8m0fnu()) { + elem_type = "Float8E8M0FNU"; + } + } else if (t.is_float6()) { + if (t.is_float6_e3m2fn()) { + elem_type = "Float6E3M2FN"; + } else if (t.is_float6_e2m3fn()) { + elem_type = "Float6E2M3FN"; + } + } else if (t.is_float4()) { + if (t.is_float4_e2m1fn()) { + elem_type = "Float4E2M1FN"; + } + } else if (t.is_bool()) { + elem_type = "Boolean"; + } else if (t.is_uint()) { + if (bits == 8 
|| bits == 16 || bits == 32 || bits == 64 || bits == 128) { + elem_type = "Uint" + std::to_string(bits); + } + } else if (t.is_int()) { + if (bits == 4 || bits == 8 || bits == 16 || bits == 32 || bits == 64 || + bits == 128) { + elem_type = "Int" + std::to_string(bits); + } + } + + if (elem_type.empty()) { + LOG(FATAL) << "Cannot convert type " << t << " to CuTeDSL type!"; + } + + return "cutlass." + elem_type; +} +} // namespace + +void CodeGenTileLangCuTeDSL::PrintType(DataType t, + std::ostream &os) { // NOLINT(*) + CHECK(t.is_scalar()) << "Should not print a non-scalar type in CuTeDSL: " + << t; + os << DTypeToString(t); +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + os << "tl.make_filled_tensor((" << PrintExpr_(op->lanes) << ",), " + << PrintExpr_(op->value) << ").load()"; +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + switch (op->dtype.bits()) { + case 64: + case 32: + case 16: + case 8: + case 4: { + std::ostringstream temp; + if (std::isinf(op->value)) { + // For CuTeDSL, use Python's float('inf') instead of CUDA macros + PrintType(op->dtype, temp); + temp << "("; + if (op->value < 0) { + temp << "float('-inf')"; + } else { + temp << "float('inf')"; + } + temp << ")"; + } else if (std::isnan(op->value)) { + // For CuTeDSL, use Python's float('nan') + PrintType(op->dtype, temp); + temp << "(float('nan'))"; + } else { + // For CuTeDSL, use Python's float.fromhex() with hexfloat for full + // precision + PrintType(op->dtype, temp); + temp << "(float.fromhex('" << std::hexfloat << op->value << "'))"; + } + MarkConst(temp.str()); + os << temp.str(); + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const CastNode *op, + std::ostream &os) { // NOLINT(*) + DataType from_ty = op->value.dtype(); + DataType target_ty = op->dtype; + ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + + if (from_ty.is_scalar()) + return CodeGenTileLangPY::VisitExpr_(op, os); + + // Emit this as vectorized unary ops. + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << target_ty.lanes() << ",), "; + PrintType(target_ty.element_of(), stream); + stream << ")\n"; + + std::string src = SSAGetID(PrintExpr_(op->value), from_ty); + + PrintIndent(); + stream << sret << ".store(" << src << ".to("; + PrintType(target_ty.element_of(), stream); + stream << "))\n"; + os << sret << ".load()"; + return; +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const DivNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_int() || op->dtype.is_uint()) { + PrintBinaryExpr_("//", op->dtype, op->a, op->b, os); + } else { + if (enable_fastmath_) { + os << "tl.divf(" << PrintExpr_(op->a) << ", " << PrintExpr_(op->b) + << ", fastmath=True)"; + } else { + PrintBinaryExpr_("tl.divf", op->dtype, op->a, op->b, os); + } + } +} +void CodeGenTileLangCuTeDSL::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("tl.min", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangCuTeDSL::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("tl.max", op->dtype, op->a, op->b, os); +} + +/** + * @brief Emit CuTeDSL-specific code for a call expression. 
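+ *
+ * Illustrative example (a sketch, not an exhaustive specification): with the
+ * default eviction policy, `tl::tma_store(desc, data, x, y, 0, 0)` is lowered
+ * to the CuTeDSL statement `tl.tma_store(desc, data, (x, y))`, and
+ * `ptx_commit_group()` is lowered to `tl.cp_async_commit()`.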
+ * + * This visitor handles CallNode intrinsics and builtins that require emitting + * CuTeDSL-specific code (inline PTX/ASM sequences, TensorLanguage runtime + * calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based + * stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The + * function writes the generated code to the provided output stream and falls + * back to the Python codegen for unrecognized calls. + * + * The method recognizes and emits code for (non-exhaustive): cp.async and its + * commit/wait variants, tma_load/store and im2col variants, ptX + * ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy + * MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX + * asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret + * paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm + * and related external calls, and other TL runtime calls. + * + * Side effects: + * - Emits to `os` and the internal codegen output stream. + * - May set internal feature flags (e.g., need_cooperative_groups_). + * - May open/close SSA scopes and mutate internal variable mappings. + * - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument + * patterns. + * + * @param op The call node to generate code for; the function inspects op->op + * and op->args to determine the appropriate emission. + * @param os Output stream to receive expression-level output when the caller + * expects an expression result (some paths write directly to the + * member stream instead). + */ +void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + auto print_extern_call_stmt = [&](std::string name, size_t start = 0, + size_t end = 0) { + // Cache context into a private ss, otherwise the let node may generate + // within the function call arguments. 
+ std::ostringstream ss; + for (size_t i = start; i < op->args.size() - end; i++) { + if (i > start) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + + PrintIndent(); + stream << name << "("; + stream << ss.str(); + stream << ")\n"; + }; + + auto print_mbarrier_obj = [&](PrimExpr barrier_id) { + std::ostringstream ss; + if (barrier_id.as()) { + // incase the barrier_id is an integer, we need to print the barrier_id as + // an integer + ss << "(" << mbarrier_name_ << "+" << barrier_id << ")"; + } else { + // otherwise may be a T.get_mbarrier() call or BufferLoad Node + // we need to print the barrier_id as a string + ss << PrintExpr_(barrier_id); + } + return ss.str(); + }; + + if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = PrintExpr_(op->args[0]); + std::string dst_offset = PrintExpr_(op->args[1]); + std::string src = PrintExpr_(op->args[2]); + std::string src_offset = PrintExpr_(op->args[3]); + std::string size = PrintExpr_(op->args[4]); + // use size of argument list to indicate whether or not to use predicated + // cp.async + if (op->args.size() == 5) { + PrintIndent(); + stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset + << ", " << src << ", " << src_offset << ")\n"; + } else { + std::string condition = PrintExpr_(op->args[5]); + PrintIndent(); + stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", " + << dst_offset << ", " << src << ", " << src_offset << ", " + << condition << ")\n"; + } + } else if (op->op.same_as(builtin::ptx_commit_group())) { + print_extern_call_stmt("tl.cp_async_commit"); + } else if (op->op.same_as(builtin::ptx_wait_group())) { + print_extern_call_stmt("tl.cp_async_wait"); + } else if (op->op.same_as(builtin::create_barriers())) { + PrintIndent(); + int barrier_count = Downcast(op->args[0])->value; + stream << mbarrier_name_ + << " = tl.alloc_smem(cutlass.Uint64, size_in_elems=" << barrier_count + << ")\n"; + } else if (op->op.same_as(tl::get_mbarrier())) { + ICHECK_EQ(op->args.size(), 1); + std::string barrier_id = PrintExpr_(op->args[0]); + os << "(" << mbarrier_name_ << "+" << barrier_id << ")"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + if (op->args.size() == 1) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + stream << "tl.mbarrier_arrive(" << mbarrier_obj << ")\n"; + } else if (op->args.size() == 3) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto cta_id = PrintExpr_(op->args[1]); + auto pred = PrintExpr_(op->args[2]); + stream << "tl.mbarrier_arrive(" << mbarrier_obj << ", " << cta_id << ", " + << pred << ")\n"; + } else { + LOG(FATAL) << "Invalid parameter for tl::arrive_barrier " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto arrive_count = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_init(" << mbarrier_obj << ", " << arrive_count + << ")\n"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + if (op->args.size() == 2) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ")\n"; + } else if (op->args.size() == 4) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + auto cta_id = 
PrintExpr_(op->args[2]); + auto pred = PrintExpr_(op->args[3]); + stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ", " << cta_id << ", " << pred << ")\n"; + } else { + LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + print_extern_call_stmt("tl.mbarrier_cp_async_arrive"); + } else if (op->op.same_as(tl::ptx_fence_barrier_init())) { + print_extern_call_stmt("tl.fence_barrier_init"); + } else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { + print_extern_call_stmt("tl.mbarrier_cp_async_arrive_noinc"); + } else if (op->op.same_as(tl::mbarrier_expect_tx())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ")\n"; + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto phase = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_wait(" << mbarrier_obj << ", " << phase << ")\n"; + } else if (op->op.same_as(tl::ptx_init_tensor_memory())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::no_set_max_nreg())) { + // do nothing + } else if (op->op.same_as(tl::tma_load())) { + std::ostringstream ss; + ICHECK_GE(op->args.size(), 2); + auto pol = op->args[op->args.size() - 1].as(); + ICHECK(pol) << "Eviction policy must be IntImm"; + ICHECK_GE(pol->value, 0); + ICHECK_LT(static_cast(pol->value), eviction_policy_names_.size()); + auto eviction_policy = eviction_policy_names_[pol->value]; + // Simplify the code by using the default eviction policy + if (eviction_policy != "EVICT_NORMAL") { + LOG(FATAL) << "Eviction policy " << eviction_policy + << " is not supported currently"; + } else { + ss << "tl.tma_load("; + } + auto desc = op->args[0]; + ss << PrintExpr_(desc) << ", "; + ss << print_mbarrier_obj(op->args[1]) << ", "; + ss << PrintExpr_(op->args[2]) << ", ("; + for (size_t i = 3; i < op->args.size() - 1; i++) { + if (i > 3) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + ss << "))\n"; + PrintIndent(); + stream << ss.str(); + } else if (op->op.same_as(tl::tma_load_im2col())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tma_store())) { + std::stringstream ss; + // Check minimum argument count (desc, data, at least one coord, + // need_reduce, eviction) + ICHECK_GE(op->args.size(), 4) << "tma_store requires at least 4 arguments " + "(desc, data, coords..., need_reduce, " + "eviction_policy), got " + << op->args.size(); + + // Safely extract need_reduce flag + auto need_reduce_ptr = op->args[op->args.size() - 2].as(); + ICHECK(need_reduce_ptr) + << "tma_store need_reduce flag (args[-2]) must be IntImm, got " + << op->args[op->args.size() - 2]->GetTypeKey(); + auto need_reduce = need_reduce_ptr->value; + if (need_reduce) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } + + // Safely extract and validate eviction policy index + auto eviction_idx_ptr = op->args[op->args.size() - 1].as(); + ICHECK(eviction_idx_ptr) + << "tma_store eviction policy (args[-1]) must be IntImm, got " + << op->args[op->args.size() - 1]->GetTypeKey(); + 
ICHECK_GE(eviction_idx_ptr->value, 0) + << "tma_store eviction policy index must be >= 0, got " + << eviction_idx_ptr->value; + ICHECK_LT(static_cast(eviction_idx_ptr->value), + eviction_policy_names_.size()) + << "tma_store eviction policy index " << eviction_idx_ptr->value + << " out of bounds (max " << eviction_policy_names_.size() - 1 << ")"; + auto eviction_policy = eviction_policy_names_[eviction_idx_ptr->value]; + + ss << "tl.tma_store("; + auto desc = op->args[0]; + ss << PrintExpr_(desc) << ", "; + ss << PrintExpr_(op->args[1]) << ", ("; + for (size_t i = 2; i < op->args.size() - 2; i++) { + if (i > 2) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + ss << ")"; + if (eviction_policy != "EVICT_NORMAL") { + ss << ", eviction_kind = nvvm.EvictKind." << eviction_policy.substr(6); + } + ss << ")\n"; + PrintIndent(); + stream << ss.str(); + } else if (op->op.same_as(tl::ptx_ldmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl.ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::ptx_stmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl.ptx_stmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::fence_proxy_async())) { + print_extern_call_stmt("tl.fence_proxy_async"); + } else if (op->op.same_as(tl::tma_store_arrive())) { + print_extern_call_stmt("tl.tma_store_arrive"); + } else if (op->op.same_as(tl::tma_store_wait())) { + PrintIndent(); + stream << "tl.tma_store_wait(0)\n"; + } else if (op->op.same_as(tl::warpgroup_arrive())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_commit_batch())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_wait())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_fence_operand())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::set_max_nreg())) { + PrintIndent(); + int nreg = Downcast(op->args[0])->value; + int is_inc = Downcast(op->args[1])->value; + std::string func_name = + is_inc ? 
"tl.warpgroup_reg_alloc" : "tl.warpgroup_reg_dealloc"; + stream << func_name << "(" << nreg << ")\n"; + } else if (op->op.same_as(tl::wait_wgmma())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::pack_b16())) { + os << "tl.pack_half2(" << PrintExpr_(op->args[0]) << ", " + << PrintExpr_(op->args[1]) << ")"; + } else if (op->op.same_as(tl::sync_grid())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::loop_break())) { + PrintIndent(); + stream << "break\n"; + } else if (op->op.same_as(builtin::ptx_mma())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_mma_sm70())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_mma_sp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_wgmma_ss())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_wgmma_rs())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tcgen05_mma_arrive())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_ldmatrix())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::mma_store())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::mma_fill())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_wait_barrier())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_ldg32())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::reinterpret())) { + DataType tgt_dtype = op->dtype; + DataType src_dtype = op->args[0]->dtype; + ICHECK_EQ(tgt_dtype.lanes() * tgt_dtype.bits(), + src_dtype.lanes() * src_dtype.bits()) + << "reinterpret expects source and target to have the same number of " + "bits"; + + const BufferLoadNode *load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + + PrimExpr index = load->indices[0]; + if (const RampNode *node = index.as(); node) { + auto *p_stride = as_const_int(node->stride); + CHECK(p_stride); + ICHECK_EQ(*p_stride, 1) << "reinterpret expects contiguous elements"; + index = node->base; + } + + auto ptr_str = GetBufferPtr_(load->buffer.get(), index); + os << "tl.make_tensor(tl.recast_ptr(" << ptr_str << ", dtype="; + PrintType(tgt_dtype.element_of(), os); + os << "), (" << tgt_dtype.lanes() << ",)).load()"; + } else if (op->op.same_as(builtin::thread_return())) { + os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + + auto op_instance = Downcast(op->args[0]); + PrintCallExtern_(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_sp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if 
(op->op.same_as(tl::get_lane_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_idx_sync())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_group_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tl_shuffle_elect())) { + os << "tl.shuffle_elect(" << PrintExpr_(op->args[0]) << ")"; + } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::increase_descriptor_offset())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::__exp())) { + os << "tl.exp2(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__exp10())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::__log())) { + os << "tl.log(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__log2())) { + os << "tl.log2(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__log10())) { + os << "tl.log10(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__tan())) { + os << "tl.tan(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__cos())) { + os << "tl.cos(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__sin())) { + os << "tl.sin(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::ieee_add())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_sub())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_mul())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fmaf())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_frcp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fsqrt())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_frsqrt())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fdiv())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_sum())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_max())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_min())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_bitand())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_bitor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::address_of())) { + const BufferLoadNode *load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + os << GetBufferPtr_(load->buffer.get(), load->indices[0]); + } else { + CodeGenTileLangPY::VisitExpr_(op, os); + } +} + 
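+// Informal lowering sketches for the CallNode visitor above. Argument names
+// are placeholders for illustration; the emission code above is authoritative.
+//
+//   ptx_cp_async(dst, d_off, src, s_off, 16)
+//       -> tl.cp_async_gs(16, dst, d_off, src, s_off)
+//   ptx_commit_group()
+//       -> tl.cp_async_commit()
+//   mbarrier_wait_parity(get_mbarrier(0), phase)
+//       -> tl.mbarrier_wait((mbarrier+0), phase)
+//   tma_store(desc, data, x, y, /*need_reduce=*/0, /*eviction=*/0)
+//       -> tl.tma_store(desc, data, (x, y))
+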
+void CodeGenTileLangCuTeDSL::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + const int value_lanes = value_dtype.lanes(); + if (value_lanes == element_dtype.lanes()) { + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index); + if (ref.back() == ')') { + ref += ".load()"; + } + os << ref; + } else { + ICHECK_GE(value_lanes, element_dtype.lanes()) + << "Unsupported load/store: value lanes < buffer element lanes"; + bool is_contiguous = false; + arith::PVar base; + if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()) + .Match(index)) { + is_contiguous = true; + } + + if (is_contiguous) { + std::string ref = + GetBufferRef_(value_dtype, op->buffer.get(), base.Eval()); + if (ref.back() == ')') { + ref += ".load()"; + } + os << ref; + } else { + ICHECK(element_dtype.is_scalar()) + << "buffer element type for non-contiguous load must be scalar " + "currently"; + + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << value_lanes << ",), "; + PrintType(element_dtype, stream); + stream << ")\n"; + + std::string vid = GetVarID(buffer_var.get()); + const RampNode *ramp = index.as(); + ICHECK(ramp) + << "Expected Ramp index for vectorized non-contiguous access"; + for (int i = 0; i < value_lanes; ++i) { + auto idx_expr = + arith::Analyzer().Simplify(ramp->base + ramp->stride * i); + + PrintIndent(); + stream << sret << "[" << i << "] = " + << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr) + << "\n"; + } + os << sret << ".load()"; + } + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const BufferStoreNode *op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + std::string value_str = PrintExpr_(op->value); + + int value_lanes = value_dtype.lanes(); + if (value_lanes == element_dtype.lanes()) { + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr); + PrintIndent(); + + if (ref.back() != ')') { + stream << ref << " = " << RemoveOutermostParentheses(value_str) << "\n"; + } else { + stream << ref << ".store(" << RemoveOutermostParentheses(value_str) + << ")\n"; + } + } else { + bool is_contiguous = false; + arith::PVar base; + if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()) + .Match(index_expr)) { + is_contiguous = true; + } + + if (is_contiguous) { + PrintVecStore_(op->buffer.get(), value_dtype, base.Eval(), value_str); + } else { + ICHECK(element_dtype.is_scalar()) + << "buffer element type for non-contiguous store must be scalar " + "currently"; + + // store elements separately + value_str = SSAGetID(value_str, element_dtype); + for (int i = 0; i < value_lanes; ++i) { + const RampNode *ramp = index_expr.as(); + ICHECK(ramp); + auto idx_expr = + arith::Analyzer().Simplify(ramp->base + ramp->stride * i); + + PrintIndent(); + stream << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr) + << " = "; + PrintVecElemLoad_(value_str, value_dtype, 
i, stream); + stream << "\n"; + } + } + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + + if (scope == "local.descriptor.wgmma") { + stream << vid << " = tl.GmmaDescriptor()\n"; + } else if (scope == "local.descriptor.tcgen05_smem") { + LOG(FATAL) << "Currently unsupported scope: " << scope; + } else if (scope == "local.descriptor.tcgen05_instr") { + LOG(FATAL) << "Currently unsupported scope: " << scope; + } else if (scope == "shared.dyn") { + stream << vid << " = tl.make_tensor(tl.get_dyn_smem("; + PrintType(op->dtype, stream); + // there is no bound check for Tensor access, so just set shape to 1 + stream << ", alignment=1024), (1,))\n"; + } else { + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now, but get " + << constant_size << " for " << op->buffer_var->name_hint; + + if (scope == "shared") { + stream << vid << " = tl.make_tensor(tl.alloc_smem("; + PrintType(op->dtype, stream); + stream << ", " << constant_size << "), (" << constant_size << ",))\n"; + } else if (scope == "shared.barrier") { + ICHECK(false) << "Unsupported scope: " << scope; + } else if (scope == "local") { + stream << vid << " = tl.make_rmem_tensor((" << constant_size << "),"; + PrintType(op->dtype, stream); + stream << ")\n"; + } else if (scope == "local.var") { + PrimExpr init = tir::make_const(op->dtype, 0); + auto init_it = op->annotations.find(tl::attr::kLocalVarInit); + if (init_it != op->annotations.end()) { + PrimExpr user_init = Downcast((*init_it).second); + if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) { + user_init = tir::Cast(op->dtype, user_init); + } + init = user_init; + } + stream << vid << " = " << PrintExpr_(init) << "\n"; + } else { + ICHECK(false) << "Unsupported scope: " << scope; + } + } + + RegisterHandleType_(op->buffer_var.get(), op->dtype); + PrintStmt_(op->body); +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (!iv->thread_tag.empty()) { + if (!var_idmap_.count(iv->var.get())) { + BindThreadIndex_(iv); + } + } + VisitStmt(op->body); + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode *queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + VisitExpr(commit_group, stream); + } else if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = + Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + VisitExpr(wait_group, stream); + auto inner = op->body.as(); + ICHECK(inner); + VisitStmt(inner->body); + } else if (op->attr_key == "threadblock_swizzle_pattern") { + this->PrintIndent(); + const StringImmNode *pattern = op->value.as(); + ICHECK(pattern); + std::string call_str = pattern->value; + // replace :: with . 
and replace < with ( and replace > with ) + ReplaceAll(call_str, "::", "."); + ReplaceAll(call_str, "<", "("); + ReplaceAll(call_str, ">", ")"); + this->stream << "blockIdx = " << call_str << "\n"; + this->VisitStmt(op->body); + } else if (op->attr_key == "pragma_unroll_factor") { + const IntImmNode *factor = op->value.as(); + ICHECK(factor); + unroll_factor_[op->node.as()] = Downcast(factor); + CodeGenTileLangPY::VisitStmt_(op); + } else { + CodeGenTileLangPY::VisitStmt_(op); + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const ForNode *op) { + if (op->kind != tir::ForKind::kUnrolled) { + CodeGenTileLangPY::VisitStmt_(op); + return; + } + + auto start_expr = arith::Analyzer().Simplify(op->min); + auto stop_expr = arith::Analyzer().Simplify(op->extent + op->min); + std::string unroll_factor; + if (auto it = unroll_factor_.find(op->loop_var.get()); + it != unroll_factor_.end()) { + unroll_factor = PrintExpr_(it->second); + } + bool use_range_constexpr = unroll_factor.empty() && + as_const_int(op->extent) != nullptr && + *as_const_int(op->extent) <= LOOP_UNROLL_THRESHOLD; + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + stream << "for " << vid << " in cutlass.range"; + if (use_range_constexpr) { + stream << "_constexpr"; + } + stream << "("; + if (!is_zero(start_expr)) { + PrintExpr_(start_expr, stream); + stream << ", "; + } + PrintExpr_(stop_expr, stream); + if (!unroll_factor.empty()) { + stream << ", unroll=" << unroll_factor; + } else if (!use_range_constexpr) { + stream << ", unroll_full=True"; + } + stream << "):\n"; + int for_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(for_scope); +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const IfThenElseNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "if " << RemoveOutermostParentheses(cond) << ":\n"; + int then_scope = BeginScope(); + if (const CallNode *call = op->condition.as(); + call && call->op.same_as(tl::tl_shuffle_elect())) { + PrintIndent(); + stream << "with cute.arch.elect_one():\n"; + int with_scope = BeginScope(); + PrintStmt_(op->then_case); + EndScope(with_scope); + } else { + PrintStmt_(op->then_case); + } + EndScope(then_scope); + + if (op->else_case) { + PrintIndent(); + stream << "else:\n"; + int else_scope = BeginScope(); + PrintStmt_(op->else_case.value()); + EndScope(else_scope); + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const EvaluateNode *op) { + if (is_const_int(op->value)) + return; + const CallNode *call = op->value.as(); + if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { + LOG(FATAL) << "Currently unsupported op: " << call->op; + } + if (call && (call->op.same_as(tvm::tl::device_assert()))) { + std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0])); + PrintIndent(); + stream << "assert " << cond << "\n"; + } else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) { + std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0])); + std::string msg_expr = PrintExpr_(call->args[1]); + PrintIndent(); + stream << "assert " << cond << ", " << msg_expr << "\n"; + } else if (call && call->op.same_as(builtin::tvm_storage_sync())) { + PrintStorageSync_(call); + } else { + CodeGenTileLangPY::VisitStmt_(op); + } +} + +void CodeGenTileLangCuTeDSL::PrintVecElemLoad_(const std::string &vec, + DataType t, int i, + std::ostream &os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + os << vec << "[" << i << "]"; +} + +void CodeGenTileLangCuTeDSL::PrintVecElemStore_(const 
std::string &vec, + DataType t, int i, + const std::string &value) { + PrintIndent(); + stream << vec << "[" << i << "] = " << value << "\n"; +} + +void CodeGenTileLangCuTeDSL::PrintVecStore_(const BufferNode *buffer, + DataType t, PrimExpr base, + const std::string &value) { + ICHECK(!t.is_scalar()) << "PrintVecStore_() should not be used for scalar"; + + std::string ref = GetBufferRef_(t, buffer, base); + PrintIndent(); + stream << ref << ".store(" << value << ")\n"; +} + +void CodeGenTileLangCuTeDSL::PrintVecBinaryOp_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + // Declare the result. + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << dtype.lanes() << ",), "; + PrintType(dtype.element_of(), stream); + stream << ")\n"; + + std::string vlhs = SSAGetID(PrintExpr_(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr_(rhs), rhs.dtype()); + + const std::string one_char_op{"+-*%<>^|&"}; + const std::string two_char_op{"// == != <= >="}; + if ((opstr.size() == 1 && one_char_op.find(opstr) != std::string::npos) || + (opstr.size() == 2 && two_char_op.find(opstr) != std::string::npos)) { + PrintIndent(); + stream << sret << ".store(" << vlhs << " " << opstr << " " << vrhs << ")\n"; + } else { + // Unpack into individual ops. + for (int i = 0, lanes = dtype.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(opstr[0])) { + value_temp << opstr << "("; + PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp); + value_temp << opstr; + PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore_(sret, dtype, i, value_temp.str()); + } + } + os << sret << ".load()"; +} + +void CodeGenTileLangCuTeDSL::PrintBinaryExpr_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + if (dtype.is_scalar()) { + CodeGenTileLangPY::PrintBinaryExpr_(opstr, dtype, lhs, rhs, os); + } else { + PrintVecBinaryOp_(opstr, dtype, lhs, rhs, os); + } +} + +void CodeGenTileLangCuTeDSL::PrintBinaryIntrinsic_( + const CallNode *op, const char *opstr, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_scalar()) { + CodeGenTileLangPY::PrintBinaryIntrinsic_(op, opstr, os); + } else { + PrintVecBinaryOp_(opstr, op->dtype, op->args[0], op->args[1], os); + } +} + +void CodeGenTileLangCuTeDSL::PrintCallExtern_(Type ret_type, + ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + + std::string global_symbol_str = global_symbol; + ReplaceAll(global_symbol_str, "::", "."); + + std::vector sargs; + // when the template arguments occurs at the end, merge them with function + // arguments + if (global_symbol_str.back() == '>') { + auto pos = global_symbol_str.rfind('<'); + ICHECK(pos != std::string::npos); + std::string template_args = + global_symbol_str.substr(pos + 1, global_symbol_str.size() - pos - 2); + ReplaceAll(template_args, "true", "True"); + ReplaceAll(template_args, "false", "False"); + sargs.push_back(template_args); + + global_symbol_str.resize(pos); + } + const size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string sarg = 
PrintExpr_(args[i]); + if (ret_dtype.is_fixed_length_vector()) { + std::string val = SSAGetID(sarg, args[i].dtype()); + sargs.push_back(std::move(val)); + } else { + sargs.push_back(sarg); + } + } + + // Replace "<...>" with "(...)". Nested "<" is not supported + { + auto pos_left = global_symbol_str.find('<'); + while (pos_left != std::string::npos) { + auto pos_right = global_symbol_str.find('>', pos_left + 1); + if (pos_right != std::string::npos) { + auto args = + global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1); + ReplaceAll(args, "true", "True"); + ReplaceAll(args, "false", "False"); + global_symbol_str.replace(pos_left, args.size() + 2, "(" + args + ")"); + } + pos_left = global_symbol_str.find('<'); + } + } + + // Special cases: + // Map C math functions to Python/cutedsl equivalents + const auto canonicalized_global_symbol_str = + CanonicalizeFastmathFunctionName_(global_symbol_str); + const bool canonicalized = !canonicalized_global_symbol_str.empty(); + if (canonicalized) { + global_symbol_str = canonicalized_global_symbol_str; + } + + // Atomic Functions + if (global_symbol_str.substr(0, 6) == "Atomic") { + global_symbol_str = "tl." + global_symbol_str; + // Convert first argument (Buffer) to pointer for atomic operations + if (const BufferLoadNode *load = args[arg_begin].as()) { + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + sargs[0] = GetBufferPtr_(load->buffer.get(), load->indices[0]); + } + } + // some optional template arguments might be ommited, so add names explicitly + // for remain arguments + if (global_symbol_str == "tl.gemm_ss" || global_symbol_str == "tl.gemm_rs" || + global_symbol_str == "tl.gemm_sr" || global_symbol_str == "tl.gemm_rr") { + ICHECK(sargs.size() >= 3); + sargs[sargs.size() - 3] = "A_ptr=" + sargs[sargs.size() - 3]; + sargs[sargs.size() - 2] = "B_ptr=" + sargs[sargs.size() - 2]; + sargs[sargs.size() - 1] = "C_ptr=" + sargs[sargs.size() - 1]; + } + + if (ret_dtype.is_fixed_length_vector()) { + // maybe simplify this if TensorSSA suppports this OP + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << ret_dtype.lanes() << ",), "; + PrintType(ret_dtype.element_of(), stream); + stream << ")\n"; + + // Emit a scalar call for each lane. 
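// --- Illustrative aside (not part of the original patch) --------------------
// A sketch of what the symbol rewriting above and the per-lane loop below
// produce for a hypothetical extern call tl::fast_op<16, true>(x, y) that
// returns a float32x2. "tl::fast_op" is a made-up name, and "x"/"y" stand for
// the SSA names of the vector operands; <elem dtype> is whatever PrintType
// emits for the element type:
//
//   global symbol "tl::fast_op<16, true>"  ->  "tl.fast_op", sargs[0] = "16, True"
//
//   _1 = tl.make_rmem_tensor((2,), <elem dtype>)
//   _1[0] = tl.fast_op(16, True, x[0], y[0])
//   _1[1] = tl.fast_op(16, True, x[1], y[1])
//
// and the expression value handed back to the caller is "_1.load()".
// ----------------------------------------------------------------------------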
+ bool has_template_arg = (sargs.size() > args.size() - arg_begin); + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol_str << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j != 0) { + scall << ", "; + } + + if (j == 0 && has_template_arg) { + scall << sargs[j]; + } else { + PrintVecElemLoad_( + sargs[j], + args[arg_begin + j - static_cast(has_template_arg)] + .dtype(), + i, scall); + } + } + if (canonicalized && enable_fastmath_) { + if (!sargs.empty()) { + scall << ", "; + } + scall << "fastmath=True"; + } + scall << ")"; + PrintVecElemStore_(sret, ret_dtype, i, scall.str()); + } + os << sret << ".load()"; + } else { + os << global_symbol_str << "("; + for (size_t i = 0; i < sargs.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << sargs[i]; + } + if (canonicalized && enable_fastmath_) { + if (!sargs.empty()) { + os << ", "; + } + os << "fastmath=True"; + } + os << ")"; + } +} + +std::string CodeGenTileLangCuTeDSL::GetBufferPtr_(const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + const std::string vid = GetVarID(buffer_var); + + DataType buffer_element_dtype = buffer->dtype; + bool is_handle_type_match = + HandleTypeMatch_(buffer_var, buffer_element_dtype); + std::string ptr_str; + if (is_handle_type_match) { + ptr_str = vid + ".iterator"; + } else { + ptr_str = "tl.recast_ptr(" + vid + + ".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")"; + } + + std::string index_str = PrintExpr_(index); + return "(" + ptr_str + " + " + index_str + ")"; +} + +// The following forms can be returned: +// (1) vid +// (2) vid[i] +// (3) tl.make_tensor_at_offset(...)[0] +// (4) tl.make_tensor_at_offset(...) +// +// Form (4) is needed when the whole tensor is loaded or stored. +// It's the only form that ends with ")". Using this fact, BufferLoadNode will +// add ".load()" and BufferStoreNode will add ".store()". +std::string CodeGenTileLangCuTeDSL::GetBufferRef_(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::string vid = GetVarID(buffer_var); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + if (scope == "local.var" || scope.find("local.descriptor") == 0) { + return vid; + } + + DataType buffer_element_dtype = buffer->dtype; + bool is_handle_type_match = + HandleTypeMatch_(buffer_var, buffer_element_dtype); + std::string ptr_str; + if (is_handle_type_match) { + ptr_str = vid + ".iterator"; + } else { + ptr_str = "tl.recast_ptr(" + vid + + ".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")"; + } + + const std::string index_str = PrintExpr_(index); + + if (t == buffer_element_dtype) { + if (is_handle_type_match && buffer_element_dtype.is_scalar() && + (scope == "local" || scope == "shared" || scope == "shared.dyn" || + scope == "shared.barrier")) { + // Tensors in these scopes are allocated as one-dimensional, so can be + // assessed via "[]" correctly. Other tensors may be multi-dimensional, + // and must be assessed via ptr, otherwise CuTeDSL will interpret "[]" + // access using its visiting order and layout. 
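// --- Illustrative aside (not part of the original patch) --------------------
// A sketch of the four return forms for a hypothetical float16 buffer named
// "A_shared" with index expression "i" (all names are made up, <elem dtype>
// is a placeholder for the recast target dtype):
//
//   (1) scope local.var / local.descriptor*   ->  A_shared
//   (2) matching scalar handle, shared scope  ->  A_shared[i]
//   (3) scalar access through a recast ptr    ->  tl.make_tensor_at_offset(
//         tl.recast_ptr(A_shared.iterator, dtype=<elem dtype>), i, (1,), div_by=1)[0]
//   (4) vector view, e.g. t == float16x8      ->  tl.make_tensor_at_offset(
//         A_shared.iterator, i, (8,), div_by=8)
// ----------------------------------------------------------------------------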
+ return vid + "[" + index_str + "]"; + } else { + std::ostringstream os; + os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str + << ", (1,), div_by=" << buffer_element_dtype.lanes() << ")"; + // for vector data types, ".load()" (added by BufferLoadNode) is neeed + // instead of "[0]" + if (buffer_element_dtype.is_scalar()) { + os << "[0]"; + } + return os.str(); + } + } else { + const int num = t.bits() * t.lanes(); + const int den = buffer_element_dtype.bits() * buffer_element_dtype.lanes(); + ICHECK_EQ(num % den, 0) << "Cannot form view: bitwidth not divisible"; + int buffer_size = num / den; + + std::ostringstream os; + os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str << ", (" + << buffer_size << ",), div_by=" << buffer_size << ")"; + return os.str(); + } +} + +void CodeGenTileLangCuTeDSL::BindThreadIndex_(const IterVar &iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + + auto &thread_tag = iv->thread_tag; + ICHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || + thread_tag == "threadIdx.z" || thread_tag == "blockIdx.x" || + thread_tag == "blockIdx.y" || thread_tag == "blockIdx.z"); + + // cute.arch.thread_idx() and block_idx() are Int32 + DataType from_dtype = DataType::Int(32); + var_idmap_[iv->var.get()] = + CastFromTo_(thread_tag, from_dtype, iv->var.dtype()); +} + +void CodeGenTileLangCuTeDSL::PrintStorageSync_(const CallNode *op) { + auto args = op->args; + const std::string &sync = args[0].as()->value; + if (sync == "warp") { + // do nothing + } else if (sync == "shared" || sync == "shared.dyn") { + PrintIndent(); + if (args.size() == 1) { + stream << "tl.sync_threads()\n"; + } else if (args.size() == 2) { + auto barrier_id_ptr = args[1].as(); + ICHECK(barrier_id_ptr) + << "storage_sync barrier_id (args[1]) must be IntImm, got " + << args[1]->GetTypeKey(); + auto barrier_id = barrier_id_ptr->value; + stream << "tl.sync_thread_partial(" << barrier_id << ")\n"; + } else if (args.size() == 3) { + auto barrier_id_ptr = args[1].as(); + ICHECK(barrier_id_ptr) + << "storage_sync barrier_id (args[1]) must be IntImm, got " + << args[1]->GetTypeKey(); + auto thread_count_ptr = args[2].as(); + ICHECK(thread_count_ptr) + << "storage_sync thread_count (args[2]) must be IntImm, got " + << args[2]->GetTypeKey(); + auto barrier_id = barrier_id_ptr->value; + auto thread_count = thread_count_ptr->value; + stream << "tl.sync_thread_partial(" << barrier_id << ", " << thread_count + << ")\n"; + } else { + LOG(FATAL) << "Invalid number of arguments for storage sync: " + << args.size(); + } + } else if (sync == "global") { + LOG(FATAL) << "PrintStorageSync_ for global is not supported for now"; + } else { + LOG(FATAL) << "Unknown storage sync scope: " << sync; + } +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_cutedsl.h b/tilelang/original/src/target/codegen_cutedsl.h new file mode 100644 index 0000000000000000000000000000000000000000..1d4edc5382e1cdbbd5bbc1b569cddb4f491fca5e --- /dev/null +++ b/tilelang/original/src/target/codegen_cutedsl.h @@ -0,0 +1,102 @@ +/*! 
+ * \file target/codegen_cutedsl.h + * \brief Utility to generate CuTeDSL code + */ +#ifndef TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ +#define TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ + +#include +#include +#include + +#include +#include +#include + +#include "codegen_py.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangCuTeDSL final : public CodeGenTileLangPY { +public: + CodeGenTileLangCuTeDSL(); + +protected: + void PrintFuncDecorator_(std::ostream &os) override; // NOLINT(*) + void PreFunctionBody_(const PrimFunc &f) override; + +protected: + void PrintType(DataType t, std::ostream &os) override; // NOLINT(*) + + void VisitExpr_(const BroadcastNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode *op, + std::ostream &os) override; // NOLINT(*) + + void VisitStmt_(const BufferStoreNode *op) override; + void VisitStmt_(const AllocateNode *op) override; + void VisitStmt_(const AttrStmtNode *op) override; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const IfThenElseNode *op) override; + void VisitStmt_(const EvaluateNode *op) override; + +protected: + virtual void PrintVecElemLoad_(const std::string &vec, DataType t, int i, + std::ostream &os); // NOLINT(*) + virtual void PrintVecElemStore_(const std::string &vec, DataType t, int i, + const std::string &value); + virtual void PrintVecStore_(const BufferNode *buffer, DataType t, + PrimExpr base, const std::string &value); + void PrintVecBinaryOp_(const std::string &opstr, DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os); // NOLINT(*) + void PrintBinaryExpr_(const std::string &opstr, DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) override; // NOLINT(*) + void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr, + std::ostream &os) override; // NOLINT(*) + + void PrintCallExtern_(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, + std::ostream &os) override; // NOLINT(*) + + std::string GetBufferPtr_(const BufferNode *buffer, PrimExpr index); + std::string GetBufferRef_(DataType t, const BufferNode *buffer, + PrimExpr index) override; + + /*! 
+ * \brief Print expr representing the thread tag + * \param IterVar iv The thread index to be binded; + */ + virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*) + + virtual void PrintStorageSync_(const CallNode *op); + + std::string + CanonicalizeFastmathFunctionName_(const std::string &func_name) const; + +private: + // The name of the mbarrier array in shared memory + const std::string mbarrier_name_ = "mbarrier"; + + std::unordered_map unroll_factor_; + + std::vector eviction_policy_names_ = { + "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"}; + + // Fastmath configuration (read from PassContext) + bool enable_fastmath_ = false; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ diff --git a/tilelang/original/src/target/codegen_hip.cc b/tilelang/original/src/target/codegen_hip.cc new file mode 100644 index 0000000000000000000000000000000000000000..db9b0e40874cbdc689fb58d54cf9b20e984b3b4c --- /dev/null +++ b/tilelang/original/src/target/codegen_hip.cc @@ -0,0 +1,1448 @@ +/*! + * \file target/codegen.cc + */ + +#include "codegen_hip.h" +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "target/source/ptx.h" + +namespace tvm { +namespace codegen { + +static std::string GetFP8Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "_2"; + } else if (lanes == 4) { + vec = "_4"; + } else if (lanes == 8) { + vec = "_8"; + } else if (lanes == 16) { + vec = "_16"; + } else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) " + "for FP8"; + } + if (type.code() == DataType::kFloat8_e4m3fn) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e4m3fnuz) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e4m3) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e4m3b11fnuz) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e5m2) { + stream << "fp8_e5" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e5m2fnuz) { + stream << "fp8_e5" << vec << "_t"; + } else if (type.code() == DataType::kFloat8_e8m0fnu) { + stream << "fp8_e8" << vec << "_t"; + } else { + LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type; + } + return stream.str(); +} + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. 
+ */ +class Replacer { +public: + void register_rule(const std::string &pattern, + const std::string &replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto &&rule : _rules) { + auto [pattern, replacement] = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + +private: + std::vector> _rules; +}; + +CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; } + +void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) { + os << "extern \"C\" __global__ "; +} + +class LaunchConfigExtractor : public tir::StmtVisitor { +private: + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->var->name_hint == "threadIdx.x" || + iv->thread_tag == "threadIdx.x") { + threadIdx_x_ext = op->value; + } else if (iv->var->name_hint == "threadIdx.y" || + iv->thread_tag == "threadIdx.y") { + threadIdx_y_ext = op->value; + } else if (iv->var->name_hint == "threadIdx.z" || + iv->thread_tag == "threadIdx.z") { + threadIdx_z_ext = op->value; + } + } + StmtVisitor::VisitStmt_(op); + } + +public: + PrimExpr threadIdx_x_ext = Integer(1); + PrimExpr threadIdx_y_ext = Integer(1); + PrimExpr threadIdx_z_ext = Integer(1); +}; + +void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) { + LaunchConfigExtractor extractor; + extractor(f->body); + arith::Analyzer analyzer; + PrimExpr threadIdx_ext = + analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * + extractor.threadIdx_z_ext); + if (const IntImmNode *const threadIdx_ext_int = + threadIdx_ext.as()) { + if (threadIdx_ext_int->value == 1) { + // unable to extract the number of threads per block, hence directly + // return + return; + } + stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + } +} + +std::string CodeGenTileLangHIP::Finish() { + // hip must need a header file. 
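// --- Illustrative aside (not part of the original patch) --------------------
// A minimal usage sketch for the Replacer helper defined above, assuming the
// made-up placeholder names shown here. Rules are applied in registration
// order, and each rule replaces every occurrence of its pattern.
#include <cassert>
#include <string>

static void ReplacerUsageSketch() {
  Replacer r;
  r.register_rule("{A_dtype}", "float16x4");
  r.register_rule("{a_ref}", "A_local");
  std::string out = r.rewrite("*((({A_dtype}*){a_ref}) + 0)");
  // Both placeholders are substituted in place, mirroring how the mfma/mmac
  // code templates later in this file are instantiated.
  assert(out == "*(((float16x4*)A_local) + 0)");
}
// ----------------------------------------------------------------------------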
+ decl_stream << "#define HIP_ENABLE_WARP_SYNC_BUILTINS\n"; + decl_stream << "#include \n"; + if (need_mma_h_) { + decl_stream << "#include \n"; + } + + if (enable_fp8_) { + decl_stream << "#include \n"; + } + + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "\n"; + return CodeGenC::Finish(); +} + +void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) { + if (op->kind == tir::ForKind::kUnrolled) { + PrintIndent(); + stream << "#pragma unroll\n"; + } + std::string extent = + PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + std::string start = PrintExpr(op->min); + stream << "for ("; + PrintType(op->loop_var.dtype(), stream); + stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent + << "; ++" << vid << ") {\n"; + int for_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(for_scope); + PrintIndent(); + stream << "}\n"; +} + +void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); +} + +void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK(t.is_scalar()) << "do not yet support vector types"; + os << "void*"; + return; + } + + if (t.is_void()) { + os << "void"; + return; + } + + if (t == tl::cuTensorMapType()) { + os << "CUtensorMap"; + return; + } + + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: + if (t.is_scalar()) { + os << "half_t"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp16 vector elements. + // + // half4 is stored as uint2 + // + // h4.x is emitted as *(half2*)(&(u2.x)).x + // h4.y is emitted as *(half2*)(&(u2.x)).y + // h4.z is emitted as *(half2*)(&(u2.y)).x + // h4.w is emitted as *(half2*)(&(u2.y)).y + // + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else { + fail = true; + } + break; + case 32: + if (lanes <= 4) { + os << "float"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. + // + // float8 is stored as ulonglong4 + // + // f8.v1 is emitted as *(float2*)(&(ul4.x)).x + // f8.v2 is emitted as *(float2*)(&(ul4.x)).y + // + ICHECK_EQ(lanes % 2, 0) + << "only support even lane for float type with lanes > 4"; + os << "ulonglong" << lanes / 2; + } else { + fail = true; + } + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && (t.is_scalar() || t.bits() == 16)) + return; + if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) + return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } else if (t.is_bfloat16()) { + if (t.is_scalar()) { + os << "bfloat16_t"; + } else if (lanes <= 8) { + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else { + fail = true; + } + if (!fail) + return; + } else if (t.is_float8()) { + enable_fp8_ = true; + os << GetFP8Type(t); + return; + } else if (t == DataType::Bool()) { + os << "bool"; + return; + } else if (t.is_vector_bool()) { + // CUDA does not support bool vectors. + // Use ushort vectors to represent instead. 
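// --- Illustrative aside (not part of the original patch) --------------------
// A few example mappings produced by this PrintType implementation, for
// reference when reading the branches above and below:
//
//   float16      -> half_t        float16x4 -> uint2      (packed half pairs)
//   bfloat16     -> bfloat16_t    float32x4 -> float4
//   float32x8    -> ulonglong4    int8x4    -> int        (packed bytes)
//   int32x4      -> int4          boolx4    -> ushort4
//   float8_e4m3  -> fp8_e4_t
// ----------------------------------------------------------------------------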
+ int n = t.lanes(); + if (n <= 4) { + os << "ushort" << n; + return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << "u"; + } + switch (t.bits()) { + case 1: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 8) { + os << "int8_t"; + return; + } else if (t.lanes() == 16) { + os << "int16_t"; + return; + } else if (t.lanes() == 32) { + os << "int"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 4: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 4) { + os << "int16_t"; + return; + } else if (t.lanes() == 8) { + // directly 8 4-bit int in integer. + os << "int"; + return; + } else if (t.lanes() == 16) { + os << "int2"; + return; + } else if (t.lanes() == 32) { + os << "int4"; + return; + } else if (t.lanes() == 64) { + os << "int8"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 8: { + if (t.lanes() == 4) { + // directly 4 8 bit int in integer. + + // We use int for int8x4 instead of char4 because using char4 is + // likely to produce extra instructions to pack four int8 elements + // into 32-bit data. + os << "int"; + return; + } else if (t.lanes() == 8) { + os << "int2"; + return; + } else if (t.lanes() == 16) { + os << "int4"; + return; + } else if (!t.is_uint() && t.is_scalar()) { + os << "signed char"; + break; + } else { + os << "char"; + break; + } + } + case 16: { + if (t.is_scalar()) { + os << "short"; + } else if (t.lanes() <= 4) { + os << "short" << lanes; + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int16 vector elements. + // + // short4 is stored as int2 + // + // s4.x is emitted as *(short2*)(&(i2.x)).x + // s4.y is emitted as *(short2*)(&(i2.x)).y + // s4.z is emitted as *(short2*)(&(i2.y)).x + // s4.w is emitted as *(short2*)(&(i2.y)).y + // + ICHECK_EQ(t.lanes() % 2, 0) + << "only support even lane for shorT type with lanes > 4"; + os << "int" << t.lanes() / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 32: { + if (t.is_scalar()) { + os << "int"; + } else if (t.lanes() <= 4) { + os << "int" << t.lanes(); + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. + // + // int8 is stored as longlong4 + // + // i8.v1 is emitted as *(int2*)(&(l4.x)).x + // i8.v2 is emitted as *(int2*)(&(l4.x)).y + // + ICHECK_EQ(lanes % 2, 0) + << "only support even lane for int32 type with lanes > 4"; + os << "longlong" << lanes / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 64: { + if (t.is_scalar()) { + os << "int64_t"; + } else if (t.lanes() == 2) { + os << "longlong2"; + } else if (t.lanes() == 3) { + os << "longlong3"; + } else if (t.lanes() == 4) { + os << "longlong4"; + } + return; + } + default: + fail = true; + break; + } + if (!fail && lanes == 1) { + return; + } + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; +} + +void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t, + PrimExpr lhs, PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + // Declare the result. + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + int ssa_scope = BeginScope(); + { + // Unpack into individual ops. 
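// --- Illustrative aside (not part of the original patch) --------------------
// Sketch of the expansion below for a hypothetical float32x4 addition, where
// "_1" is the fresh result name and "va"/"vb" stand for the SSA operands:
//
//   float4 _1;
//   _1.x = (va.x + vb.x);
//   _1.y = (va.y + vb.y);
//   _1.z = (va.z + vb.z);
//   _1.w = (va.w + vb.w);
//
// The expression value handed back to the caller is then simply "_1".
// ----------------------------------------------------------------------------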
+ std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); + } + } + EndScope(ssa_scope); + os << sret; +} + +void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t, + int i, + std::ostream &os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + + static const char access[] = {'x', 'y', 'z', 'w'}; + ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 + : (t.bits() == 16 || t.bits() == 32) ? 8 + : 4)); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + std::string type_name = t.is_int() ? "char" : "unsigned char"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else { + std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; + } + } else if (t.is_float16()) { + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2]; + } else if (t.is_bfloat16()) { + os << "((bfloat16x2*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2]; + } else if (t.lanes() > 4 && t.lanes() <= 8) { + std::string type_name; + if (t.bits() == 16) { + if (t.is_int()) { + type_name = "short"; + } else if (t.is_uint()) { + type_name = "ushort"; + } + } else if (t.bits() == 32) { + if (t.is_int()) { + type_name = "int"; + } else if (t.is_uint()) { + type_name = "uint"; + } else if (t.is_float()) { + type_name = "float"; + } + } + ICHECK(!type_name.empty()); + os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] + << ")))->" << access[i % 2]; + } else { + os << vec << "." << access[i]; + } +} + +void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t, + int i, const std::string &value) { + this->PrintIndent(); + static const char access[] = {'x', 'y', 'z', 'w'}; + ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 + : (t.bits() == 16 || t.bits() == 32) ? 8 + : 4)); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (t.lanes() == 2 || t.lanes() == 3) { + stream << vec << '.' << access[i % t.lanes()] << "=" + << "(" << value << ");\n"; + } else { + std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + stream << ac << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; + } + stream << "(" << value << " << " << i % 4 * 8 << ");\n"; + } + } else if (t.is_float16()) { + stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2] + << ")))->" << access[i % 2] << "))) = " << value << ";\n"; + } else if (t.is_bfloat16()) { + stream << "((bfloat16_t*)(&(" << vec << "." 
<< access[i / 2] << ")))[" + << (i % 2) << "] = " << value << ";\n"; + } else if (t.lanes() > 4 && t.lanes() <= 8) { + std::string type_name; + if (t.bits() == 16) { + if (t.is_int()) { + type_name = "short"; + } else if (t.is_uint()) { + type_name = "ushort"; + } + } else if (t.bits() == 32) { + if (t.is_int()) { + type_name = "int"; + } else if (t.is_uint()) { + type_name = "uint"; + } else if (t.is_float()) { + type_name = "float"; + } + } + ICHECK(!type_name.empty()); + stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] + << ")))->" << access[i % 2] << " = " << value << ";\n"; + } else { + stream << vec << "." << access[i] << " = " << value << ";\n"; + } +} + +void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) { + const std::string &sync = op->args[0].as()->value; + if (sync == "warp") { + // DO nothing. + } else if (sync == "shared" || sync == "shared.dyn") { + this->PrintIndent(); + this->stream << "__syncthreads();\n"; + } +} + +void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope, + std::ostream &os) { // NOLINT(*) + ICHECK_NE(scope, "global") + << "Cannot allocate global memory when targeting CUDA. You must pass " + "all global arrays as input instead"; + if (scope == "shared") { + os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ __align__(1024) "; + } +} + +std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, + DataType target) { + if (from == target) + return value; + std::ostringstream os; + os << "(("; + this->PrintType(target, os); + os << ")"; + if (from.is_float16() && (target.is_int() || target.is_uint()) && + target.bits() == 8) { + os << "("; + if (target.is_uint()) { + os << "u"; + } + os << "int)"; + } + os << value << ")"; + return os.str(); +} + +void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) { + DataType from_ty = op->value.dtype(); + DataType target_ty = op->dtype; + ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + + // Emit simple C-style type conversion. + if (from_ty.is_scalar()) + return CodeGenC::VisitExpr_(op, os); + + // We could emit make_float4 like calls, but the emitted code looks + // too compact to read. Emit this as vectorized unary ops. + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(target_ty, stream); + stream << ' ' << sret << ";\n"; + { + std::string src = SSAGetID(PrintExpr(op->value), from_ty); + for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { + std::ostringstream val; + val << "("; + PrintType(target_ty.element_of(), val); + val << ")("; + PrintVecElemLoad(src, from_ty, i, val); + val << ")"; + PrintVecElemStore(sret, target_ty, i, val.str()); + } + } + os << sret; +} + +void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, + const Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + if (ret_dtype.is_vector()) { + // + // Emit an unsupported vector call + // + // v = intrin_f((float4*)A[0], (float4*)B[0]) + // + // as + // + // float4 __ret; + // { + // float4 __arg0 = ((float4*)A)[0]; + // float4 __arg1 = ((float4*)B)[0]; + // __ret.x = intrin_f(__arg0.x, __arg1.x); + // __ret.y = intrin_f(__arg0.y, __arg1.y); + // __ret.z = intrin_f(__arg0.z, __arg1.z); + // __ret.w = intrin_f(__arg0.w, __arg1.w); + // } + // v = __ret; + // + // Declare the result vector. 
+ std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(ret_dtype, stream); + stream << ' ' << sret << ";\n"; + { + // Load arguments. + std::vector sargs; + size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + sargs.push_back(std::move(val)); + } + + // Emit a scalar call for each lane. + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j > 0) + scall << ", "; + PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + } + scall << ")"; + PrintVecElemStore(sret, ret_dtype, i, scall.str()); + } + } + os << sret; + } else { + CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, + os); + } +} + +// Print a reference expression to a buffer. +std::string CodeGenTileLangHIP::GetBufferRef(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::ostringstream os; + std::string vid = GetVarID(buffer_var); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + // bool is_vol = IsVolatile(buffer_var); + // always false for tl cutlass backend. + bool is_vol = false; + + auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { + std::ostringstream ptr_os; + ptr_os << "("; + if (is_vol) { + ptr_os << "volatile "; + } + if (!scope.empty() && IsScopePartOfType()) { + PrintStorageScope(scope, ptr_os); + } + PrintType(pointed_to, ptr_os); + ptr_os << "*)"; + return ptr_os.str(); + }; + + DataType buffer_element_dtype = buffer->dtype; + + std::string buffer_str = vid; + if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { + std::stringstream temp; + temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; + buffer_str = temp.str(); + } + + std::string index_str = PrintExpr(index); + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { + // This is a special case, because CodegenCUDA::PrintType() + // returns "int" for bool and for 4-bit integers. In most cases, + // we divide by the number of lanes to determine the index. + // However, the backing type for scalar int4 and scalar bool is + // int32. Therefore, we need to divide by the ratio of their + // sizes in that case. + int div_factor = (t.lanes() == 1) ? 
(32 / t.bits()) : t.lanes(); + + os << "*(" + << "(" << ptr_cast(t) << vid << ")" + << " + " << index_str << " / " << div_factor << ")"; + } else if (t == buffer_element_dtype) { + os << buffer_str << "[" << index_str << "]"; + } else { + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; + } + + return os.str(); +} + +void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { + auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) { + this->PrintIndent(); + this->stream << name << "("; + for (size_t i = offset; i < op->args.size(); i++) { + if (i > offset) + this->stream << ", "; + this->stream << this->PrintExpr(op->args[i]); + } + this->stream << ");\n"; + }; + if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = this->PrintExpr(op->args[0]); + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + // use size of argument list to indicate whether or not to use predicated + // cp.async + if (op->args.size() == 5) { + this->PrintIndent(); + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" + << dst_offset << ", " << src << "+" << src_offset << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[5]); + this->PrintIndent(); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << "+" << dst_offset << ", " << src << "+" << src_offset + << ", " << condition << ");\n"; + } + } else if (op->op.same_as(builtin::ptx_commit_group())) { + print_extern_call_stmt("tl::cp_async_commit"); + } else if (op->op.same_as(builtin::ptx_wait_group())) { + int n = Downcast(op->args[0])->value; + std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; + print_extern_call_stmt(func_name, 1); + } else if (op->op.same_as(builtin::create_barriers())) { + this->PrintIndent(); + int barrier_count = Downcast(op->args[0])->value; + std::string barrier_name = "_mbarrier"; + this->stream << "__shared__ uint64_t " << barrier_name << "[" + << barrier_count << "];\n"; + } else if (op->op.same_as(tl::get_mbarrier())) { + std::string barrier_name = "_mbarrier"; + std::string barrier_id = this->PrintExpr(op->args[0]); + os << barrier_name + "[" + barrier_id + "]"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + print_extern_call_stmt("tl::mbarrier_arrive"); + } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { + print_extern_call_stmt("tl::mbarrier_init"); + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + print_extern_call_stmt("tl::mbarrier_arrive_expect_tx"); + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + print_extern_call_stmt("tl::mbarrier_cp_async_arrive"); + } else if (op->op.same_as(tl::mbarrier_expect_tx())) { + print_extern_call_stmt("tl::mbarrier_expect_tx"); + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + print_extern_call_stmt("tl::mbarrier_wait"); + } else if (op->op.same_as(tl::ptx_stmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::wait_wgmma())) { + this->PrintIndent(); + int num_mma = Downcast(op->args[0])->value; + this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; + } 
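// --- Illustrative aside (not part of the original patch) --------------------
// Sketch of the ptx_cp_async lowering handled earlier in this dispatch, with
// made-up operand names. Five call arguments select the unconditional form; a
// sixth argument selects the predicated variant:
//
//   tl::cp_async_gs<16>(dst+dst_offset, src+src_offset);
//   tl::cp_async_gs_conditional<16>(dst+dst_offset, src+src_offset, pred);
// ----------------------------------------------------------------------------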
else if (op->op.same_as(tl::pack_b16())) { + os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " + << this->PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::__ldg())) { + // HIP fallback: regular load + const BufferLoadNode *bl = op->args[0].as(); + ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument."; + ICHECK_EQ(bl->indices.size(), 1) + << "T.__ldg currently supports flattened 1D buffer accesses."; + const BufferNode *buffer = bl->buffer.get(); + PrimExpr base = bl->indices[0]; + auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base); + os << buffer_ref; + } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 6U); + os << "nvcuda::wmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[6], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::store_matrix_sync("; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[6], os); + if (const StringImmNode *str = op->args[7].as()) { + os << ", nvcuda::wmma::mem_" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->op.same_as(builtin::tvm_mma_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", " : ")"); + } + } else if (op->op.same_as(builtin::tvm_bmma_sync())) { + need_mma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::bmma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", " : ")"); + } + } else if (op->op.same_as(tl::tvm_mfma())) { + // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype} + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: float16, float32, ... + // arg 4: B precision: float16, float32, ... + // arg 5: C precision: float32, float64, ... 
+ // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + + ICHECK(op->args.size() == 12U) + << "Invalid number of arguments for tvm_mfma"; + std::string prefix = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_bias = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_bias = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_bias = this->PrintExpr(op->args[11]); + ICHECK(A_layout == "row" || B_layout == "row") + << "Matrix core only support row major"; + // map for dtype -> float32x4 -> float4 + std::unordered_map dtype_map = { + {"int8", "char"}, + {"int32", "int"}, + {"int8x4", "int32_t"}, + {"int8x8", "int64_t"}, + {"int32x4", "int32x4"}, + {"float16", "half"}, + {"float32", "float"}, + {"float64", "double"}, + {"float16x4", "float16x4"}, + {"bfloat16x4", "bfloat16x4_vec"}, + {"float32x4", "float32x4"}, + {"float8_e4m3fnuzx4", "fp8_e4_4_t"}, + {"float8_e4m3fnuzx8", "long"}, + {"float32x16", "float32x16"}}; + std::string call_mfma_code = R"({ + *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), + *((({B_dtype}*){b_ref}) + {b_bias}), + *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0); + })"; + std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix; + Replacer replacer; + + replacer.register_rule("{mfma_buildin}", mfma_buildin); + replacer.register_rule("{A_dtype}", dtype_map[A_dtype]); + replacer.register_rule("{B_dtype}", dtype_map[B_dtype]); + replacer.register_rule("{C_dtype}", dtype_map[C_dtype]); + replacer.register_rule("{a_ref}", a_ref); + replacer.register_rule("{a_bias}", a_bias); + replacer.register_rule("{b_ref}", b_ref); + replacer.register_rule("{b_bias}", b_bias); + replacer.register_rule("{c_ref}", c_ref); + replacer.register_rule("{c_bias}", c_bias); + os << replacer.rewrite(call_mfma_code); + } else if (op->op.same_as(tl::tvm_mmac())) { + // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype} + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: float16, float32, ... + // arg 4: B precision: float16, float32, ... + // arg 5: C precision: float32, float64, ... 
+ // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + + ICHECK(op->args.size() == 12U) + << "Invalid number of arguments for tvm_mmac"; + std::string prefix = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_bias = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_bias = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_bias = this->PrintExpr(op->args[11]); + ICHECK(A_layout == "row" || B_layout == "row") + << "Matrix core only support row major"; + // map for dtype -> float32x4 -> float4 + std::unordered_map dtype_map = { + {"int8", "char"}, + {"int32", "int"}, + {"int8x4", "int32_t"}, + {"int8x8", "int64_t"}, + {"int32x4", "int32x4"}, + {"float16", "half"}, + {"float32", "float"}, + {"float64", "double"}, + {"float16x4", "float16x4"}, + {"bfloat16x4", "bfloat16x4"}, + {"float32x4", "float32x4"}, + {"float8_e4m3fnuzx4", "fp8_e4_4_t"}, + {"float8_e4m3fnuzx8", "long"}, + {"float32x16", "float32x16"}}; + std::string call_mmac_code = R"({ + *((({C_dtype}*){c_ref}) + {c_bias}) = {mmac_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), + *((({B_dtype}*){b_ref}) + {b_bias}), + *((({C_dtype}*){c_ref}) + {c_bias})); + })"; + std::string mmac_buildin = "__builtin_amdgcn_mmac_" + prefix; + Replacer replacer; + + replacer.register_rule("{mmac_buildin}", mmac_buildin); + replacer.register_rule("{A_dtype}", dtype_map[A_dtype]); + replacer.register_rule("{B_dtype}", dtype_map[B_dtype]); + replacer.register_rule("{C_dtype}", dtype_map[C_dtype]); + replacer.register_rule("{a_ref}", a_ref); + replacer.register_rule("{a_bias}", a_bias); + replacer.register_rule("{b_ref}", b_ref); + replacer.register_rule("{b_bias}", b_bias); + replacer.register_rule("{c_ref}", c_ref); + replacer.register_rule("{c_bias}", c_bias); + os << replacer.rewrite(call_mmac_code); + } else if (op->op.same_as(builtin::thread_return())) { + os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_sp())) { + LOG(FATAL) << "tl_gemm_sp is not supported on HIP"; + } else if (op->op.same_as(tl::loop_break())) { + this->PrintIndent(); + this->stream << "break;\n"; + } else if (op->op.same_as(tl::no_set_max_nreg())) { + // HIP doesn't need explicit register management like CUDA + // This is a no-op for HIP + return; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode *queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + this->VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + this->VisitExpr(commit_group, this->stream); + return; + } else if (op->attr_key == 
tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = + Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + this->VisitExpr(wait_group, this->stream); + auto inner = op->body.as(); + ICHECK(inner); + this->VisitStmt(inner->body); + return; + } else if (op->attr_key == "threadblock_swizzle_pattern") { + this->PrintIndent(); + const StringImmNode *pattern = op->value.as(); + ICHECK(pattern); + this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; + this->VisitStmt(op->body); + return; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + this->PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + + if (scope == "shared.dyn") { + stream << ' ' << vid << "[];\n"; + } else { + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + stream << ' ' << vid << '[' << constant_size << "];\n"; + } + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); +} + +void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) { + int lanes = static_cast(Downcast(op->lanes)->value); + CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; + os << "(make_"; + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < lanes; i++) { + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != lanes - 1) + os << ", "; + } + os << "))"; +} + +void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + int lanes = static_cast(Downcast(op->lanes)->value); + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && + lanes == 4) { + // make_int8x4 + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + return; + } + + if (op->dtype.is_float16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if (op->dtype.is_bfloat16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_bfloat162(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if (op->dtype.is_float() && op->dtype.bits() == 32 && + op->dtype.lanes() == 8) { + std::string v = PrintExpr(op->value); + os << "make_ulonglong4("; + for (int i = 0; i < 4; ++i) { + if (i != 0) + os << ", "; + os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if ((op->dtype.is_int() || 
op->dtype.is_uint()) && op->dtype.bits() == 4) { + bool fail = false; + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xF; + + if (lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; + if (op->dtype.is_uint()) { + os << "(uint16_t)" << v; + } else { + os << "(int16_t)" << v; + } + } else { + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | + (v << 4) | v; + if (lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (lanes == 16 || lanes == 32) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 8; ++i) { + if (i != 0) + os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + + if (!fail) { + return; + } + } + + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << ')'; +} + +inline void PrintConst(const FloatImmNode *op, std::ostream &os, + CodeGenTileLangHIP *p) { // NOLINT(*) + // Type code is kBFloat + if (op->dtype.is_bfloat16()) { + os << "bfloat16_t"; + os << '(' << std::scientific << op->value << 'f' << ')'; + return; + } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() || + op->dtype.is_float8_e4m3fn()) { + os << "fp8_e4_t"; + os << '(' << std::scientific << op->value << 'f' << ')'; + return; + } + // Type code is kFloat + switch (op->dtype.bits()) { + case 64: + case 32: { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL"); + } else if (std::isnan(op->value)) { + temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN"); + } else { + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) + temp << 'f'; + } + p->MarkConst(temp.str()); + os << temp.str(); + break; + } + case 16: { + os << "half_t" << '('; + FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); + PrintConst(const_f32.get(), os, p); + os << ')'; + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + PrintConst(op, os, this); +} + +void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value, + const BufferLoadNode *op, + std::ostream &os) { + // Cast away volatile qualifier for fp16 types. That is, only loads and + // stores are volatile. The loaded objects are not marked as volatile. 
+ // + if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && + IsVolatile(op->buffer->data.get())) { + os << "("; + PrintType(op->dtype, os); + os << ")(" << value << ")"; + } else { + os << value; + } +} + +void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, + const std::string &value, + std::ostream &os) { + ICHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (!(t.lanes() == 2 || t.lanes() == 3)) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 + << "))"; + return; + } + } + + if (t.is_float16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (t.is_bfloat16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_bfloat162(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << "("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + return; +} + +void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) { + // clear previous generated state. + this->InitFuncState(f); + // reserve keywords + ReserveKeywordsAsUnique(); + + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.has_value()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + std::unordered_set non_restrict; + if (auto opt = + f->GetAttr>(tl::attr::kNonRestrictParams)) { + for (const tir::Var &v : opt.value()) + non_restrict.insert(v.get()); + } + + this->PrintFuncPrefix(stream); + CodeGenC::PrintType(f->ret_type, stream); + this->PrintExtraAttrs(f, stream); + this->stream << " " << static_cast(global_symbol.value()) << "("; + for (size_t i = 0; i < f->params.size(); ++i) { + tir::Var v = f->params[i]; + std::string vid = AllocVarID(v.get()); + if (i != 0) + stream << ", "; + if (v.dtype().is_handle()) { + // work around for grid constant parameters. 
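// --- Illustrative aside (not part of the original patch) --------------------
// Sketch of the parameter declarations this loop emits, using made-up names:
// an ordinary buffer handle becomes a restrict-qualified pointer, while a
// handle whose storage scope is "grid_constant" is passed by value:
//
//   float* __restrict__ A
//   __grid_constant__ const SomeElemType desc
// ----------------------------------------------------------------------------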
+ if (auto *ptr = v->type_annotation.as()) { + if (ptr->storage_scope == "grid_constant") { + stream << "__grid_constant__ const "; + CodeGenC::PrintType(ptr->element_type, stream); + stream << ' ' << vid; + continue; + } + } + + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } + + CodeGenC::PrintType(GetType(v), stream); + if (auto *ptr = v->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } + + if (no_alias && !non_restrict.count(v.get())) { + PrintRestrict(v, stream); + } + } else { + CodeGenC::PrintType(GetType(v), stream); + } + stream << ' ' << vid; + } + stream << ") {\n"; + this->PreFunctionBody(f); + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_hip.h b/tilelang/original/src/target/codegen_hip.h new file mode 100644 index 0000000000000000000000000000000000000000..631050feb61059b2594dad0c6e55f2eb4ab46f35 --- /dev/null +++ b/tilelang/original/src/target/codegen_hip.h @@ -0,0 +1,96 @@ +/*! + * \file target/codegen.h + * \brief Utility to generate code + */ +#ifndef TVM_TL_TARGET_CODEGEN_HIP_H_ +#define TVM_TL_TARGET_CODEGEN_HIP_H_ + +#include +#include +#include + +#include +#include + +#include "target/source/codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangHIP final : public CodeGenC { +public: + CodeGenTileLangHIP(); + std::string Finish(); + // override behavior + void PrintFuncPrefix(std::ostream &os) final; + void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final; + void VisitStmt_(const ForNode *op) final; + void PrintStorageSync(const CallNode *op) final; + void PrintStorageScope(const std::string &scope, + std::ostream &os) final; // NOLINT(*) + void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string &vec, DataType t, int i, + std::ostream &os) final; // NOLINT(*) + void PrintVecElemStore(const std::string &vec, DataType t, int i, + const std::string &value) final; + void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) + void PrintVecElemLoadExpr(DataType t, int i, const std::string &value, + std::ostream &os) final; + std::string CastFromTo(std::string value, DataType from, + DataType target) final; + // overload visitor + void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; + void VisitExpr_(const CallNode *op, std::ostream &os) final; + void VisitExpr_(const CastNode *op, std::ostream &os) final; + void VisitStmt_(const AllocateNode *op) final; + void VisitStmt_(const AttrStmtNode *op) final; + + // Override this as a work around for __grid_constant__ parameter + void AddFunction(const PrimFunc &f); + +protected: + virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, + PrimExpr index) final; + void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, + std::ostream &os) final; // NOLINT(*) + +private: + // Handle volatile loads + void HandleVolatileLoads(const std::string 
&value, const BufferLoadNode *op, + std::ostream &os) final; + + // Whether scope such as "__shared__" or "__constant__" is part of type. + bool IsScopePartOfType() const final { return false; } + + friend void PrintConst(const FloatImmNode *op, std::ostream &os, + CodeGenTileLangHIP *p); + + // whether need math_constants.h + bool need_math_constants_h_{false}; + // whether need mfma.h + bool need_wmma_h_{false}; + // whether need fp8.h + bool enable_fp8_{false}; + // The size of the barrier array in shared memory + int barrier_count_ = -1; + // whether need mma.h + bool need_mma_h_{false}; + // whether need cast_smem_ptr_to_int helper function + bool need_cast_smem_ptr_to_int_{false}; + // The name of the barrier array in shared memory + const std::string barrier_name_ = "barrier"; + // The alignment of the barrier array in shared memory + // Set to 16 to maintain minimum alignment requirements for async bulk copy + const int barrier_alignment_bytes_ = 16; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TL_TARGET_CODEGEN_HIP_H_ diff --git a/tilelang/original/src/target/codegen_py.cc b/tilelang/original/src/target/codegen_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa12eef094a36fc490e42c4706062c283e90d272 --- /dev/null +++ b/tilelang/original/src/target/codegen_py.cc @@ -0,0 +1,715 @@ +/*! + * \file codegen_py.cc + */ +#include "codegen_py.h" +#include "codegen_utils.h" + +#include +#include + +#include + +namespace tvm { +namespace codegen { + +void CodeGenTileLangPY::AddFunction(const GlobalVar &gvar, const PrimFunc &f) { + RegisterFunction_(gvar, f); + auto function_name = GetFunctionName_(gvar); + + // clear previous generated state. + InitFuncState_(f); + + PrintFuncDecorator_(stream); + PrintFunctionSignature_(function_name, f, stream); + stream << ":\n"; + + int func_scope = BeginScope(); + PreFunctionBody_(f); + PrintStmt_(f->body); + EndScope(func_scope); +} + +std::string CodeGenTileLangPY::Finish() { + std::ostringstream code; + code << decl_stream.str(); + code << stream.str(); + return code.str(); +} + +ffi::String CodeGenTileLangPY::GetFunctionName_(const GlobalVar &gvar) { + auto it = internal_functions_.find(gvar); + ICHECK(it != internal_functions_.end()) + << "Attempted to find name of " << gvar + << ", but no function with this GlobalVar has been declared"; + return it->second; +} + +void CodeGenTileLangPY::RegisterFunction_(const GlobalVar &gvar, + const PrimFunc &func) { + if (internal_functions_.count(gvar)) { + return; + } + + auto function_name = [&]() -> ffi::String { + if (auto global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = global_symbol.value(); + ICHECK(!func_name_supply_->ContainsName(name)) + << "Function " << gvar << " must use global symbol " << name + << ", but this name has already been used."; + func_name_supply_->ReserveName(name); + return name; + } else { + ICHECK(!func_name_supply_->ContainsName(gvar->name_hint)) + << "Function " << gvar << " must use name hint " << gvar->name_hint + << ", but this name has already been used."; + func_name_supply_->ReserveName(gvar->name_hint); + return gvar->name_hint; + } + }(); + internal_functions_.insert({gvar, function_name}); +} + +void CodeGenTileLangPY::InitFuncState_(const PrimFunc &f) { + alloc_storage_scope_.clear(); + handle_data_type_.clear(); + CodeGenSourceBase::ClearFuncState(); + ReserveKeywordsAsUnique_(); +} + +void CodeGenTileLangPY::PrintFunctionSignature_( + const ffi::String &function_name, const PrimFunc &func, + 
std::ostream &os) { // NOLINT(*) + os << "def " << function_name << "("; + for (size_t i = 0; i < func->params.size(); ++i) { + tir::Var v = func->params[i]; + if (i > 0) { + os << ", "; + } + os << AllocVarID(v.get()); + } + os << ")"; + + // Register handle data type + for (const auto ¶m : func->params) { + if (auto *ptr = param->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType_(param.get(), prim->dtype); + } + } + } +} + +void CodeGenTileLangPY::ReserveKeywordsAsUnique_() { + // skip the first underscore, so SSA variable starts from _1 + name_supply_->ReserveName("_"); + name_supply_->ReserveName("False"); + name_supply_->ReserveName("None"); + name_supply_->ReserveName("True"); + name_supply_->ReserveName("and"); + name_supply_->ReserveName("as"); + name_supply_->ReserveName("assert"); + name_supply_->ReserveName("async"); + name_supply_->ReserveName("await"); + name_supply_->ReserveName("break"); + name_supply_->ReserveName("class"); + name_supply_->ReserveName("continue"); + name_supply_->ReserveName("def"); + name_supply_->ReserveName("del"); + name_supply_->ReserveName("elif"); + name_supply_->ReserveName("else"); + name_supply_->ReserveName("except"); + name_supply_->ReserveName("finally"); + name_supply_->ReserveName("for"); + name_supply_->ReserveName("from"); + name_supply_->ReserveName("global"); + name_supply_->ReserveName("if"); + name_supply_->ReserveName("import"); + name_supply_->ReserveName("in"); + name_supply_->ReserveName("is"); + name_supply_->ReserveName("lambda"); + name_supply_->ReserveName("nonlocal"); + name_supply_->ReserveName("not"); + name_supply_->ReserveName("or"); + name_supply_->ReserveName("pass"); + name_supply_->ReserveName("raise"); + name_supply_->ReserveName("return"); + name_supply_->ReserveName("try"); + name_supply_->ReserveName("while"); + name_supply_->ReserveName("with"); + name_supply_->ReserveName("yield"); + + name_supply_->ReserveName("void"); + name_supply_->ReserveName("int"); + name_supply_->ReserveName("float"); + name_supply_->ReserveName("double"); + name_supply_->ReserveName("char"); + name_supply_->ReserveName("unsigned"); + name_supply_->ReserveName("short"); + name_supply_->ReserveName("long"); + + name_supply_->ReserveName("cutlass"); + name_supply_->ReserveName("cute"); + name_supply_->ReserveName("tl"); +} + +void CodeGenTileLangPY::PrintSSAAssign(const std::string &target, + const std::string &src, DataType t) { + stream << target << " = " << RemoveOutermostParentheses(src) << "\n"; +} + +void CodeGenTileLangPY::PrintType(DataType type, + std::ostream &os) { // NOLINT(*) + if (type.is_float()) { + if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) { + os << "float"; + } else { + LOG(FATAL) << "Cannot convert float" << type.bits() << " to Python type"; + } + } else if (type.is_uint()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "int"; + break; + } + case 1: + os << "bool"; + break; + default: + LOG(FATAL) << "Cannot convert uint" << type.bits() << " to Python type"; + } + } else if (type.is_int()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "int"; + break; + } + case 1: + os << "bool"; + break; + default: + LOG(FATAL) << "Cannot convert int" << type.bits() << " to Python type"; + } + } else { + LOG(FATAL) << "Cannot convert type " << type << " to Python type"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const VarNode *op, + std::ostream &os) { // NOLINT(*) + os << GetVarID(op); +} + +void 
CodeGenTileLangPY::VisitExpr_(const IntImmNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype == DataType::Bool()) { + os << (op->value ? "True" : "False"); + } else { + std::ostringstream temp; + temp << op->value; + MarkConst(temp.str()); + os << temp.str(); + } +} + +void CodeGenTileLangPY::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + switch (op->dtype.bits()) { + case 64: + case 32: { + std::ostringstream temp; + temp << "float.fromhex('" << std::hexfloat << op->value << "')"; + MarkConst(temp.str()); + os << temp.str(); + break; + } + case 16: { + PrintType(op->dtype, os); + os << "(float.fromhex('" << std::hexfloat << op->value << "'))"; + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op, + std::ostream &os) { // NOLINT(*) + EscapeStringLiteral_(op->value, os); +} + +void CodeGenTileLangPY::VisitExpr_(const CastNode *op, + std::ostream &os) { // NOLINT(*) + std::stringstream value; + PrintExpr_(op->value, value); + os << CastFromTo_(value.str(), op->value.dtype(), op->dtype); +} + +void CodeGenTileLangPY::VisitExpr_(const AddNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("+", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const SubNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("-", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const MulNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("*", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const DivNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_int() || op->dtype.is_uint()) { + PrintBinaryExpr_("//", op->dtype, op->a, op->b, os); + } else { + PrintBinaryExpr_("/", op->dtype, op->a, op->b, os); + } +} +void CodeGenTileLangPY::VisitExpr_(const ModNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK(op->dtype.is_int() || op->dtype.is_uint() || op->dtype.is_float()) + << "Expected floating point or integer dtype in Mod, but got " + << op->dtype; + PrintBinaryExpr_("%", op->dtype, op->a, op->b, os); +} + +void CodeGenTileLangPY::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("min", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("max", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const EQNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("==", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const NENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("!=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const LTNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("<", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const LENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("<=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const GTNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_(">", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const GENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_(">=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const AndNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("and", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const OrNode *op, + 
std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("or", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const NotNode *op, + std::ostream &os) { // NOLINT(*) + os << "(not "; + PrintExpr_(op->a, os); + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const SelectNode *op, + std::ostream &os) { // NOLINT(*) + os << "("; + PrintExpr_(op->true_value, os); + os << " if "; + PrintExpr_(op->condition, os); + os << " else "; + PrintExpr_(op->false_value, os); + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const RampNode *op, + std::ostream &os) { // NOLINT(*) + int lanes = op->dtype.lanes(); + os << "("; + for (int i = 0; i < lanes; i++) { + os << "(" << PrintExpr_(op->base) << ")" + << "+(" << PrintExpr_(op->stride) << "*" << i << ")"; + if (i != lanes - 1) + os << ", "; + } + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + if (auto opt_call_op = op->op.as()) { + const auto &call_op = opt_call_op.value(); + + if (op->op.same_as(builtin::ret())) { + os << "return " << RemoveOutermostParentheses(PrintExpr_(op->args[0])); + } else if (op->op.same_as(builtin::continue_loop())) { + os << "continue"; + } else if (op->op.same_as(builtin::break_loop())) { + os << "break"; + } else if (op->op.same_as(builtin_call_extern_) || + op->op.same_as(builtin_call_pure_extern_)) { + ICHECK_GE(op->args.size(), 1U); + auto func = Downcast(op->args[0]); + PrintCallExtern_(GetType(ffi::GetRef(op)), func->value, + op->args, true, os); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. + PrintCallExtern_(GetType(ffi::GetRef(op)), + op_attr_global_symbol_[call_op], op->args, false, os); + } else if (op->op.same_as(builtin::large_uint_imm())) { + ICHECK_EQ(op->args.size(), 2U); + uint64_t low = + static_cast(Downcast(op->args[0])->value); + uint64_t high = + static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + + if (op->dtype == DataType::UInt(32)) { + std::ostringstream temp; + temp << val; + MarkConst(temp.str()); + os << temp.str(); + } else { + PrintType(op->dtype, os); + os << "(" << val << ")"; + } + } else if (op->op.same_as(builtin::bitwise_and())) { + PrintBinaryIntrinsic_(op, "&", os); + } else if (op->op.same_as(builtin::bitwise_or())) { + PrintBinaryIntrinsic_(op, "|", os); + } else if (op->op.same_as(builtin::bitwise_xor())) { + PrintBinaryIntrinsic_(op, "^", os); + } else if (op->op.same_as(builtin::bitwise_not())) { + ICHECK_EQ(op->args.size(), 1U); + os << "~"; + PrintExpr_(op->args[0], os); + } else if (op->op.same_as(builtin::shift_left())) { + PrintBinaryIntrinsic_(op, "<<", os); + } else if (op->op.same_as(builtin::shift_right())) { + PrintBinaryIntrinsic_(op, ">>", os); + } else if (op->op.same_as(builtin::if_then_else())) { + + std::string cond = PrintExpr_(op->args[0]); + std::string true_val = PrintExpr_(op->args[1]); + std::string false_val = PrintExpr_(op->args[2]); + os << "(" << true_val << " if " << cond << " else " << false_val << ")"; + } else if (op->op.same_as(builtin::isnullptr())) { + ICHECK_EQ(op->args.size(), 1U); + os << "("; + PrintExpr_(op->args[0], os); + os << " is None)"; + } else if (op->op.same_as(builtin::isnan())) { + os << "("; + PrintExpr_(op->args[0], os); + os << " != "; + PrintExpr_(op->args[0], os); + os << ")"; + } else { + LOG(FATAL) << "Unresolved call " << op->op; + } + } else if (auto opt = op->op.as()) { + const auto &gvar = opt.value(); + auto callee_name = GetFunctionName_(gvar); + 
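+    // Note: calls that reference another PrimFunc in the IRModule resolve the
+    // callee's registered name and are emitted as ordinary Python calls, e.g.
+    // `callee_name(arg0, arg1)`.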
PrintCallExtern_(GetType(ffi::GetRef(op)), callee_name, op->args, + false, os); + } else { + LOG(FATAL) + << "CodeGenTileLangPY: Unknown operation " << op->op + << " is neither a recognized built-in, " + << "nor a GlobalVar reference to another function in the IRModule"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + ICHECK_EQ(value_dtype, element_dtype) + << "value_dtype and element_dtype must be same for a BufferLoadNode"; + std::string ref = GetBufferRef_(op->dtype, op->buffer.get(), index); + os << ref; +} + +void CodeGenTileLangPY::VisitStmt_(const BufferStoreNode *op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + + ICHECK_EQ(value_dtype, element_dtype) + << "value_dtype and element_dtype must be same for a BufferStoreNode"; + std::string value = PrintExpr_(op->value); + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr); + PrintIndent(); + stream << ref << " = " << RemoveOutermostParentheses(value) << "\n"; +} + +void CodeGenTileLangPY::VisitStmt_(const DeclBufferNode *op) { + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const LetStmtNode *op) { + std::string value = PrintExpr_(op->value); + PrintIndent(); + stream << AllocVarID(op->var.get()) << " = " << value << "\n"; + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + PrintIndent(); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + + stream << vid << " = [None] * " << constant_size << "\n"; + + RegisterHandleType_(op->buffer_var.get(), op->dtype); + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const AttrStmtNode *op) { + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const ForNode *op) { + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + stream << "for " << vid << " in range("; + if (is_zero(op->min)) { + PrintExpr_(op->extent, stream); + } else { + PrintExpr_(op->min, stream); + stream << ", "; + PrimExpr upper_bound = arith::Analyzer().Simplify(op->extent + op->min); + PrintExpr_(upper_bound, stream); + } + stream << "):\n"; + int for_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(for_scope); +} + +void CodeGenTileLangPY::VisitStmt_(const WhileNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "while " << RemoveOutermostParentheses(cond) << ":\n"; + int while_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(while_scope); +} + +void CodeGenTileLangPY::VisitStmt_(const IfThenElseNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "if " << 
RemoveOutermostParentheses(cond) << ":\n"; + int then_scope = BeginScope(); + PrintStmt_(op->then_case); + EndScope(then_scope); + + if (op->else_case) { + PrintIndent(); + stream << "else:\n"; + int else_scope = BeginScope(); + PrintStmt_(op->else_case.value()); + EndScope(else_scope); + } +} + +void CodeGenTileLangPY::VisitStmt_(const SeqStmtNode *op) { + for (Stmt stmt : op->seq) { + PrintStmt_(stmt); + } +} + +void CodeGenTileLangPY::VisitStmt_(const EvaluateNode *op) { + if (is_const_int(op->value)) + return; + + std::string vid = PrintExpr_(op->value); + if (!vid.empty()) { + PrintIndent(); + stream << vid << "\n"; + } +} + +void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + if (const auto *str = op->message.as()) { + stream << "assert " << cond << ", "; + EscapeStringLiteral_(str->value, stream); + stream << "\n"; + } else { + stream << "assert " << cond << "\n"; + } + PrintStmt_(op->body); +} + +std::string CodeGenTileLangPY::CastFromTo_(const std::string &value, + DataType from, DataType target) { + if (from == target) + return value; + std::ostringstream os; + PrintType(target, os); + os << "(" << value << ")"; + return os.str(); +} + +void CodeGenTileLangPY::PrintBinaryExpr_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(dtype.lanes(), 1); + if (isalpha(opstr[0]) && opstr != "and" && opstr != "or") { + os << opstr << '('; + PrintExpr_(lhs, os); + os << ", "; + PrintExpr_(rhs, os); + os << ')'; + } else { + os << '('; + PrintExpr_(lhs, os); + os << ' ' << opstr << ' '; + PrintExpr_(rhs, os); + os << ')'; + } +} + +void CodeGenTileLangPY::PrintBinaryIntrinsic_(const CallNode *op, + const char *opstr, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->dtype.lanes(), 1); + ICHECK_EQ(op->args.size(), 2U); + os << '('; + PrintExpr_(op->args[0], os); + os << ' ' << opstr << ' '; + PrintExpr_(op->args[1], os); + os << ')'; +} + +void CodeGenTileLangPY::PrintCallExtern_(Type ret_type, + ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + os << global_symbol << "("; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + PrintExpr_(args[i], os); + if (i < args.size() - 1) { + os << ", "; + } + } + os << ")"; +} + +// Print a reference expression to a buffer. 
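+// For a flat buffer `A` and an index expression such as `i * 128 + j`, the
+// returned reference is the Python subscript string `A[i * 128 + j]`.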
+std::string CodeGenTileLangPY::GetBufferRef_(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::string vid = GetVarID(buffer_var); + DataType buffer_element_dtype = buffer->dtype; + + ICHECK(HandleTypeMatch_(buffer_var, buffer_element_dtype)); + ICHECK_EQ(t, buffer_element_dtype); + + std::string index_str = PrintExpr_(index); + return vid + "[" + index_str + "]"; +} + +void CodeGenTileLangPY::RegisterHandleType_(const VarNode *buf_var, + DataType t) { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) { + handle_data_type_[buf_var] = t; + } else { + ICHECK(it->second == t) << "conflicting buf var type"; + } +} + +bool CodeGenTileLangPY::HandleTypeMatch_(const VarNode *buf_var, + DataType t) const { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) + return false; + return it->second == t; +} + +void CodeGenTileLangPY::EscapeStringLiteral_(const std::string &s, + std::ostream &os) { + os << '"'; + for (unsigned char c : s) { + switch (c) { + case '\\': + os << "\\\\"; + break; + case '"': + os << "\\\""; + break; + case '\n': + os << "\\n"; + break; + case '\r': + os << "\\r"; + break; + case '\t': + os << "\\t"; + break; + case '\f': + os << "\\f"; + break; + case '\b': + os << "\\b"; + break; + default: + // Handle non-printable and non-ASCII characters + if (c < 32 || c == 127) { + // Output as \xHH + os << "\\x"; + const char hex[] = "0123456789abcdef"; + os << hex[(c >> 4) & 0xF]; + os << hex[c & 0xF]; + } else { + os << c; + } + break; + } + } + os << '"'; +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_py.h b/tilelang/original/src/target/codegen_py.h new file mode 100644 index 0000000000000000000000000000000000000000..431fe933d308527cd50dbd151bcefa7713d2b775 --- /dev/null +++ b/tilelang/original/src/target/codegen_py.h @@ -0,0 +1,255 @@ +/*! + * \file codegen_py.h + * \brief Common utilities to generate simple Python code. + */ +#ifndef TVM_TL_TARGET_CODEGEN_PY_H_ +#define TVM_TL_TARGET_CODEGEN_PY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// from tvm/src/ +#include "target/source/codegen_source_base.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace codegen { + +using namespace tir; +/*! + * \brief A base class to generate simple Python code. + */ +class CodeGenTileLangPY + : public ExprFunctor, + public StmtFunctor, + public CodeGenSourceBase { +public: + /*! + * \brief Add the function definition to the generated module. + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + */ + virtual void AddFunction(const GlobalVar &gvar, const PrimFunc &func); + + /*! + * \brief Finalize the compilation and return the code. + * \return The code. + */ + virtual std::string Finish(); + +protected: + /*! + * \brief Get the name of a declared function + * \param gvar The GlobalVar of the function + * \returns The string name of the function + */ + ffi::String GetFunctionName_(const GlobalVar &gvar); + + /*! + * \brief Reserve the function name in the generated module. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + * \param whether to append return 0 in the end. + */ + virtual void RegisterFunction_(const GlobalVar &gvar, const PrimFunc &func); + + /*! + * \brief Initialize codegen state for generating f. 
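+   * This clears the per-function storage-scope and handle-type maps, resets
+   * the shared source-base state, and re-reserves the Python keywords so that
+   * generated identifiers cannot collide with them.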
+ * \param f The function to be compiled. + */ + virtual void InitFuncState_(const PrimFunc &f); + + /*! \brief Print the function signature before ":" + * \param function_name The name of the function + * \param func The function whose signature should be printed + * \param os The output stream + */ + virtual void PrintFunctionSignature_(const ffi::String &function_name, + const PrimFunc &func, + std::ostream &os); // NOLINT(*) + + /*! + * \brief Print the function decorator + * \param os The output stream + */ + virtual void PrintFuncDecorator_(std::ostream &os) {} // NOLINT(*) + + /*! + * \brief Insert statement before function body. + * \param f The function to be compiled. + */ + virtual void PreFunctionBody_(const PrimFunc &f) {} + +protected: + /*! \brief reserves common Python keywords */ + void ReserveKeywordsAsUnique_(); + + void PrintSSAAssign(const std::string &target, const std::string &src, + DataType t) override; + +protected: + /*! + * \brief Print Type representation of type type. + * \param t The type representation. + * \param os The output stream + */ + void PrintType(DataType type, std::ostream &os) override; // NOLINT(*) + + /*! + * \brief Print the Stmt n to CodeGenTileLangPY->stream + * \param n The statement to be printed. + */ + void PrintStmt_(const Stmt &n) { VisitStmt(n); } + /*! + * \brief Print the expression n into os + * \param n The expression to be printed. + * \param os The output stream + */ + void PrintExpr_(const PrimExpr &n, std::ostream &os) { // NOLINT(*) + VisitExpr(n, os); + } + /*! + * \brief Same as PrintExpr_, but simply returns result string + * \param n The expression to be printed. + */ + std::string PrintExpr_(const PrimExpr &n) { + std::ostringstream os; + PrintExpr_(n, os); + return os.str(); + } + + // expression + void VisitExpr_(const VarNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const StringImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const AddNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const SubNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MulNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const ModNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const EQNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const NENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const LTNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const LENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const GTNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const GENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const AndNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const OrNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const NotNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const 
RampNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode *op, + std::ostream &os) override; // NOLINT(*) + + // statment + void VisitStmt_(const BufferStoreNode *op) override; + void VisitStmt_(const DeclBufferNode *op) override; + void VisitStmt_(const LetStmtNode *op) override; + void VisitStmt_(const AllocateNode *op) override; + void VisitStmt_(const AttrStmtNode *op) override; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const WhileNode *op) override; + void VisitStmt_(const IfThenElseNode *op) override; + void VisitStmt_(const SeqStmtNode *op) override; + void VisitStmt_(const EvaluateNode *op) override; + void VisitStmt_(const AssertStmtNode *op) override; + +protected: + // Get a string of type casting + virtual std::string CastFromTo_(const std::string &value, DataType from, + DataType target); + + virtual void PrintBinaryExpr_(const std::string &opstr, DataType dtype, + PrimExpr lhs, PrimExpr rhs, + std::ostream &os); // NOLINT(*) + virtual void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr, + std::ostream &os); // NOLINT(*) + + /*! + * \brief Print external function call. + * \param ret_type The return type. + * \param global_symbol The symbolc of the target function. + * \param args The arguments to the function. + * \param skip_first_arg Whether to skip the first arguments. + * \param os The output stream. + */ + virtual void PrintCallExtern_(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os); // NOLINT(*) + + // Print reference to a buffer as type t in index. + virtual std::string GetBufferRef_(DataType t, const BufferNode *buffer, + PrimExpr index); + + /*! + * \brief Register the data type of buf_var + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + void RegisterHandleType_(const VarNode *buf_var, DataType t); + + /*! + * \brief If buffer is allocated as type t. + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + bool HandleTypeMatch_(const VarNode *buf_var, DataType t) const; + +protected: + /*! \brief the storage scope of allocation */ + std::unordered_map alloc_storage_scope_; + + /*! \brief Record of ops that have pre-defined global symbol. */ + OpAttrMap op_attr_global_symbol_ = + Op::GetAttrMap("TGlobalSymbol"); + + // cache commonly used ops + const Op &builtin_call_extern_ = builtin::call_extern(); + const Op &builtin_call_pure_extern_ = builtin::call_pure_extern(); + +private: + /*! \brief the data type of allocated buffers */ + std::unordered_map handle_data_type_; + + /* \brief Map of GlobalVar to their symbol. + * + * For externally-exposed functions, this is given by the + * tvm::attr::kTarget attribute of the PrimFunc. For internal + * functions, this is the name of the function's GlobalVar, possibly + * altered to prevent duplicate names. + */ + std::unordered_map internal_functions_; + + /* \brief Name supply to generate unique function names */ + NameSupply func_name_supply_; + + /*! + * \brief Escape a string to be a valid Python double-quoted string literal. + * \param s The input string to escape. + * \param os The output stream to write the escaped string to. 
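+   * Backslash, the double quote, and common whitespace characters use their
+   * dedicated escapes; other control characters (and DEL) are emitted as
+   * \xHH escapes, while remaining bytes are copied through unchanged.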
+ */ + void EscapeStringLiteral_(const std::string &s, std::ostream &os); +}; + +} // namespace codegen +} // namespace tvm +#endif // TVM_TL_TARGET_CODEGEN_PY_H_ diff --git a/tilelang/original/src/target/codegen_utils.cc b/tilelang/original/src/target/codegen_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..75d038d3a74047e2afd6b28f25a60b76a21e01d3 --- /dev/null +++ b/tilelang/original/src/target/codegen_utils.cc @@ -0,0 +1,41 @@ +/*! + * \file target/codegen_utils.cc + * \brief Shared utility functions for code generation + */ + +#include "codegen_utils.h" + +namespace tvm { +namespace codegen { + +bool CheckOutermostParenthesesMatch(const std::string &s) { + if (!s.empty() && s.front() == '(' && s.back() == ')') { + size_t len = s.size(); + int n_unmatched = 0; + for (size_t i = 0; i < len; ++i) { + if (s[i] == '(') { + n_unmatched++; + } else if (s[i] == ')') { + n_unmatched--; + } + if (n_unmatched < 0) { + return false; + } + if (n_unmatched == 0) { + return i == len - 1; + } + } + } + return false; +} + +std::string RemoveOutermostParentheses(const std::string &s) { + if (CheckOutermostParenthesesMatch(s)) { + return s.substr(1, s.size() - 2); + } else { + return s; + } +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/codegen_utils.h b/tilelang/original/src/target/codegen_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1ef52d4b105cbd548dca9cce67a338ab4325cb9a --- /dev/null +++ b/tilelang/original/src/target/codegen_utils.h @@ -0,0 +1,33 @@ +/*! + * \file target/codegen_utils.h + * \brief Shared utility functions for code generation + */ + +#ifndef TVM_TARGET_CODEGEN_UTILS_H_ +#define TVM_TARGET_CODEGEN_UTILS_H_ + +#include + +namespace tvm { +namespace codegen { + +/*! + * \brief Check if the outermost parentheses match + * \param s The input string + * \return true if the first character is '(' and the last character is ')' + * and they form a matching pair + */ +bool CheckOutermostParenthesesMatch(const std::string &s); + +/*! + * \brief Remove outermost parentheses if they match + * \param s The input string + * \return The string with outermost parentheses removed if they match, + * otherwise return the original string + */ +std::string RemoveOutermostParentheses(const std::string &s); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_CODEGEN_UTILS_H_ diff --git a/tilelang/original/src/target/cuda.h b/tilelang/original/src/target/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..a9dfb13ab03eaad29738a473d8d83c41e0007461 --- /dev/null +++ b/tilelang/original/src/target/cuda.h @@ -0,0 +1,26649 @@ +/* + * Copyright 1993-2023 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. 
+ * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef __cuda_cuda_h__ +#define __cuda_cuda_h__ + +#include +#ifdef _MSC_VER +typedef unsigned __int32 cuuint32_t; +typedef unsigned __int64 cuuint64_t; +#else +#include +typedef uint32_t cuuint32_t; +typedef uint64_t cuuint64_t; +#endif + +#if defined(__CUDA_API_VERSION_INTERNAL) || defined(__DOXYGEN_ONLY__) || \ + defined(CUDA_ENABLE_DEPRECATED) +#define __CUDA_DEPRECATED +#elif defined(_MSC_VER) +#define __CUDA_DEPRECATED __declspec(deprecated) +#elif defined(__GNUC__) +#define __CUDA_DEPRECATED __attribute__((deprecated)) +#else +#define __CUDA_DEPRECATED +#endif + +#if defined(CUDA_FORCE_API_VERSION) +#error "CUDA_FORCE_API_VERSION is no longer supported." 
+#endif + +#if defined(__CUDA_API_VERSION_INTERNAL) || \ + defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) +#define __CUDA_API_PER_THREAD_DEFAULT_STREAM +#define __CUDA_API_PTDS(api) api##_ptds +#define __CUDA_API_PTSZ(api) api##_ptsz +#else +#define __CUDA_API_PTDS(api) api +#define __CUDA_API_PTSZ(api) api +#endif + +#define cuDeviceTotalMem cuDeviceTotalMem_v2 +#define cuCtxCreate cuCtxCreate_v2 +#define cuCtxCreate_v3 cuCtxCreate_v3 +#define cuModuleGetGlobal cuModuleGetGlobal_v2 +#define cuMemGetInfo cuMemGetInfo_v2 +#define cuMemAlloc cuMemAlloc_v2 +#define cuMemAllocPitch cuMemAllocPitch_v2 +#define cuMemFree cuMemFree_v2 +#define cuMemGetAddressRange cuMemGetAddressRange_v2 +#define cuMemAllocHost cuMemAllocHost_v2 +#define cuMemHostGetDevicePointer cuMemHostGetDevicePointer_v2 +#define cuMemcpyHtoD __CUDA_API_PTDS(cuMemcpyHtoD_v2) +#define cuMemcpyDtoH __CUDA_API_PTDS(cuMemcpyDtoH_v2) +#define cuMemcpyDtoD __CUDA_API_PTDS(cuMemcpyDtoD_v2) +#define cuMemcpyDtoA __CUDA_API_PTDS(cuMemcpyDtoA_v2) +#define cuMemcpyAtoD __CUDA_API_PTDS(cuMemcpyAtoD_v2) +#define cuMemcpyHtoA __CUDA_API_PTDS(cuMemcpyHtoA_v2) +#define cuMemcpyAtoH __CUDA_API_PTDS(cuMemcpyAtoH_v2) +#define cuMemcpyAtoA __CUDA_API_PTDS(cuMemcpyAtoA_v2) +#define cuMemcpyHtoAAsync __CUDA_API_PTSZ(cuMemcpyHtoAAsync_v2) +#define cuMemcpyAtoHAsync __CUDA_API_PTSZ(cuMemcpyAtoHAsync_v2) +#define cuMemcpy2D __CUDA_API_PTDS(cuMemcpy2D_v2) +#define cuMemcpy2DUnaligned __CUDA_API_PTDS(cuMemcpy2DUnaligned_v2) +#define cuMemcpy3D __CUDA_API_PTDS(cuMemcpy3D_v2) +#define cuMemcpyHtoDAsync __CUDA_API_PTSZ(cuMemcpyHtoDAsync_v2) +#define cuMemcpyDtoHAsync __CUDA_API_PTSZ(cuMemcpyDtoHAsync_v2) +#define cuMemcpyDtoDAsync __CUDA_API_PTSZ(cuMemcpyDtoDAsync_v2) +#define cuMemcpy2DAsync __CUDA_API_PTSZ(cuMemcpy2DAsync_v2) +#define cuMemcpy3DAsync __CUDA_API_PTSZ(cuMemcpy3DAsync_v2) +#define cuMemsetD8 __CUDA_API_PTDS(cuMemsetD8_v2) +#define cuMemsetD16 __CUDA_API_PTDS(cuMemsetD16_v2) +#define cuMemsetD32 __CUDA_API_PTDS(cuMemsetD32_v2) +#define cuMemsetD2D8 __CUDA_API_PTDS(cuMemsetD2D8_v2) +#define cuMemsetD2D16 __CUDA_API_PTDS(cuMemsetD2D16_v2) +#define cuMemsetD2D32 __CUDA_API_PTDS(cuMemsetD2D32_v2) +#define cuArrayCreate cuArrayCreate_v2 +#define cuArrayGetDescriptor cuArrayGetDescriptor_v2 +#define cuArray3DCreate cuArray3DCreate_v2 +#define cuArray3DGetDescriptor cuArray3DGetDescriptor_v2 +#define cuTexRefSetAddress cuTexRefSetAddress_v2 +#define cuTexRefGetAddress cuTexRefGetAddress_v2 +#define cuGraphicsResourceGetMappedPointer cuGraphicsResourceGetMappedPointer_v2 +#define cuCtxDestroy cuCtxDestroy_v2 +#define cuCtxPopCurrent cuCtxPopCurrent_v2 +#define cuCtxPushCurrent cuCtxPushCurrent_v2 +#define cuStreamDestroy cuStreamDestroy_v2 +#define cuEventDestroy cuEventDestroy_v2 +#define cuTexRefSetAddress2D cuTexRefSetAddress2D_v3 +#define cuLinkCreate cuLinkCreate_v2 +#define cuLinkAddData cuLinkAddData_v2 +#define cuLinkAddFile cuLinkAddFile_v2 +#define cuMemHostRegister cuMemHostRegister_v2 +#define cuGraphicsResourceSetMapFlags cuGraphicsResourceSetMapFlags_v2 +#define cuStreamBeginCapture __CUDA_API_PTSZ(cuStreamBeginCapture_v2) +#define cuDevicePrimaryCtxRelease cuDevicePrimaryCtxRelease_v2 +#define cuDevicePrimaryCtxReset cuDevicePrimaryCtxReset_v2 +#define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2 +#define cuDeviceGetUuid_v2 cuDeviceGetUuid_v2 +#define cuIpcOpenMemHandle cuIpcOpenMemHandle_v2 + +#define cuGraphInstantiate cuGraphInstantiateWithFlags + +#define cuGraphExecUpdate cuGraphExecUpdate_v2 +#define cuGetProcAddress 
cuGetProcAddress_v2 +#define cuGraphAddKernelNode cuGraphAddKernelNode_v2 +#define cuGraphKernelNodeGetParams cuGraphKernelNodeGetParams_v2 +#define cuGraphKernelNodeSetParams cuGraphKernelNodeSetParams_v2 +#define cuGraphExecKernelNodeSetParams cuGraphExecKernelNodeSetParams_v2 + +#define cuStreamWriteValue32 __CUDA_API_PTSZ(cuStreamWriteValue32_v2) +#define cuStreamWaitValue32 __CUDA_API_PTSZ(cuStreamWaitValue32_v2) +#define cuStreamWriteValue64 __CUDA_API_PTSZ(cuStreamWriteValue64_v2) +#define cuStreamWaitValue64 __CUDA_API_PTSZ(cuStreamWaitValue64_v2) +#define cuStreamBatchMemOp __CUDA_API_PTSZ(cuStreamBatchMemOp_v2) +#define cuStreamGetCaptureInfo __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) +#define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) + +#if defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) +#define cuMemcpy __CUDA_API_PTDS(cuMemcpy) +#define cuMemcpyAsync __CUDA_API_PTSZ(cuMemcpyAsync) +#define cuMemcpyPeer __CUDA_API_PTDS(cuMemcpyPeer) +#define cuMemcpyPeerAsync __CUDA_API_PTSZ(cuMemcpyPeerAsync) +#define cuMemcpy3DPeer __CUDA_API_PTDS(cuMemcpy3DPeer) +#define cuMemcpy3DPeerAsync __CUDA_API_PTSZ(cuMemcpy3DPeerAsync) +#define cuMemPrefetchAsync __CUDA_API_PTSZ(cuMemPrefetchAsync) +#define cuMemPrefetchAsync_v2 __CUDA_API_PTSZ(cuMemPrefetchAsync_v2) + +#define cuMemsetD8Async __CUDA_API_PTSZ(cuMemsetD8Async) +#define cuMemsetD16Async __CUDA_API_PTSZ(cuMemsetD16Async) +#define cuMemsetD32Async __CUDA_API_PTSZ(cuMemsetD32Async) +#define cuMemsetD2D8Async __CUDA_API_PTSZ(cuMemsetD2D8Async) +#define cuMemsetD2D16Async __CUDA_API_PTSZ(cuMemsetD2D16Async) +#define cuMemsetD2D32Async __CUDA_API_PTSZ(cuMemsetD2D32Async) + +#define cuStreamGetPriority __CUDA_API_PTSZ(cuStreamGetPriority) +#define cuStreamGetId __CUDA_API_PTSZ(cuStreamGetId) +#define cuStreamGetFlags __CUDA_API_PTSZ(cuStreamGetFlags) +#define cuStreamGetCtx __CUDA_API_PTSZ(cuStreamGetCtx) +#define cuStreamWaitEvent __CUDA_API_PTSZ(cuStreamWaitEvent) +#define cuStreamEndCapture __CUDA_API_PTSZ(cuStreamEndCapture) +#define cuStreamIsCapturing __CUDA_API_PTSZ(cuStreamIsCapturing) +#define cuStreamGetCaptureInfo_v3 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v3) +#define cuStreamUpdateCaptureDependencies \ + __CUDA_API_PTSZ(cuStreamUpdateCaptureDependencies) +#define cuStreamUpdateCaptureDependencies_v2 \ + __CUDA_API_PTSZ(cuStreamUpdateCaptureDependencies_v2) +#define cuStreamAddCallback __CUDA_API_PTSZ(cuStreamAddCallback) +#define cuStreamAttachMemAsync __CUDA_API_PTSZ(cuStreamAttachMemAsync) +#define cuStreamQuery __CUDA_API_PTSZ(cuStreamQuery) +#define cuStreamSynchronize __CUDA_API_PTSZ(cuStreamSynchronize) +#define cuEventRecord __CUDA_API_PTSZ(cuEventRecord) +#define cuEventRecordWithFlags __CUDA_API_PTSZ(cuEventRecordWithFlags) +#define cuLaunchKernel __CUDA_API_PTSZ(cuLaunchKernel) +#define cuLaunchKernelEx __CUDA_API_PTSZ(cuLaunchKernelEx) +#define cuLaunchHostFunc __CUDA_API_PTSZ(cuLaunchHostFunc) +#define cuGraphicsMapResources __CUDA_API_PTSZ(cuGraphicsMapResources) +#define cuGraphicsUnmapResources __CUDA_API_PTSZ(cuGraphicsUnmapResources) + +#define cuLaunchCooperativeKernel __CUDA_API_PTSZ(cuLaunchCooperativeKernel) + +#define cuSignalExternalSemaphoresAsync \ + __CUDA_API_PTSZ(cuSignalExternalSemaphoresAsync) +#define cuWaitExternalSemaphoresAsync \ + __CUDA_API_PTSZ(cuWaitExternalSemaphoresAsync) + +#define cuGraphInstantiateWithParams \ + __CUDA_API_PTSZ(cuGraphInstantiateWithParams) +#define cuGraphUpload __CUDA_API_PTSZ(cuGraphUpload) +#define cuGraphLaunch __CUDA_API_PTSZ(cuGraphLaunch) 
+#define cuStreamCopyAttributes __CUDA_API_PTSZ(cuStreamCopyAttributes) +#define cuStreamGetAttribute __CUDA_API_PTSZ(cuStreamGetAttribute) +#define cuStreamSetAttribute __CUDA_API_PTSZ(cuStreamSetAttribute) +#define cuMemMapArrayAsync __CUDA_API_PTSZ(cuMemMapArrayAsync) + +#define cuMemFreeAsync __CUDA_API_PTSZ(cuMemFreeAsync) +#define cuMemAllocAsync __CUDA_API_PTSZ(cuMemAllocAsync) +#define cuMemAllocFromPoolAsync __CUDA_API_PTSZ(cuMemAllocFromPoolAsync) + +#define cuStreamBeginCaptureToGraph __CUDA_API_PTSZ(cuStreamBeginCaptureToGraph) + +#endif + +/** + * \file cuda.h + * \brief Header file for the CUDA Toolkit application programming interface. + * + * \file cudaGL.h + * \brief Header file for the OpenGL interoperability functions of the + * low-level CUDA driver application programming interface. + * + * \file cudaD3D9.h + * \brief Header file for the Direct3D 9 interoperability functions of the + * low-level CUDA driver application programming interface. + */ + +/** + * \defgroup CUDA_TYPES Data types used by CUDA driver + * @{ + */ + +/** + * CUDA API version number + */ +#define CUDA_VERSION 12040 + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * CUDA device pointer + * CUdeviceptr is defined as an unsigned integer type whose size matches the + * size of a pointer on the target platform. + */ +#if defined(_WIN64) || defined(__LP64__) +typedef unsigned long long CUdeviceptr_v2; +#else +typedef unsigned int CUdeviceptr_v2; +#endif +typedef CUdeviceptr_v2 CUdeviceptr; /**< CUDA device pointer */ + +typedef int CUdevice_v1; /**< CUDA device */ +typedef CUdevice_v1 CUdevice; /**< CUDA device */ +typedef struct CUctx_st *CUcontext; /**< CUDA context */ +typedef struct CUmod_st *CUmodule; /**< CUDA module */ +typedef struct CUfunc_st *CUfunction; /**< CUDA function */ +typedef struct CUlib_st *CUlibrary; /**< CUDA library */ +typedef struct CUkern_st *CUkernel; /**< CUDA kernel */ +typedef struct CUarray_st *CUarray; /**< CUDA array */ +typedef struct CUmipmappedArray_st + *CUmipmappedArray; /**< CUDA mipmapped array */ +typedef struct CUtexref_st *CUtexref; /**< CUDA texture reference */ +typedef struct CUsurfref_st *CUsurfref; /**< CUDA surface reference */ +typedef struct CUevent_st *CUevent; /**< CUDA event */ +typedef struct CUstream_st *CUstream; /**< CUDA stream */ +typedef struct CUgraphicsResource_st + *CUgraphicsResource; /**< CUDA graphics interop resource */ +typedef unsigned long long CUtexObject_v1; /**< An opaque value that represents + a CUDA texture object */ +typedef CUtexObject_v1 + CUtexObject; /**< An opaque value that represents a CUDA texture object */ +typedef unsigned long long CUsurfObject_v1; /**< An opaque value that represents + a CUDA surface object */ +typedef CUsurfObject_v1 + CUsurfObject; /**< An opaque value that represents a CUDA surface object */ +typedef struct CUextMemory_st *CUexternalMemory; /**< CUDA external memory */ +typedef struct CUextSemaphore_st + *CUexternalSemaphore; /**< CUDA external semaphore */ +typedef struct CUgraph_st *CUgraph; /**< CUDA graph */ +typedef struct CUgraphNode_st *CUgraphNode; /**< CUDA graph node */ +typedef struct CUgraphExec_st *CUgraphExec; /**< CUDA executable graph */ +typedef struct CUmemPoolHandle_st *CUmemoryPool; /**< CUDA memory pool */ +typedef struct CUuserObject_st + *CUuserObject; /**< CUDA user object for graphs */ +typedef cuuint64_t + CUgraphConditionalHandle; /**< CUDA graph conditional handle */ +typedef struct CUgraphDeviceUpdatableNode_st + *CUgraphDeviceNode; /**< CUDA graph device node 
handle */ +typedef struct CUasyncCallbackEntry_st + *CUasyncCallbackHandle; /**< CUDA async notification callback handle */ + +#ifndef CU_UUID_HAS_BEEN_DEFINED +#define CU_UUID_HAS_BEEN_DEFINED +typedef struct CUuuid_st { /**< CUDA definition of UUID */ + char bytes[16]; +} CUuuid; +#endif + +/** + * CUDA IPC handle size + */ +#define CU_IPC_HANDLE_SIZE 64 + +/** + * Fabric handle - An opaque handle representing a memory allocation + * that can be exported to processes in same or different nodes. For IPC + * between processes on different nodes they must be connected via the + * NVSwitch fabric. + */ +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; + +/** + * CUDA IPC event handle + */ +typedef struct CUipcEventHandle_st { + char reserved[CU_IPC_HANDLE_SIZE]; +} CUipcEventHandle_v1; +typedef CUipcEventHandle_v1 CUipcEventHandle; + +/** + * CUDA IPC mem handle + */ +typedef struct CUipcMemHandle_st { + char reserved[CU_IPC_HANDLE_SIZE]; +} CUipcMemHandle_v1; +typedef CUipcMemHandle_v1 CUipcMemHandle; + +/** + * CUDA Ipc Mem Flags + */ +typedef enum CUipcMem_flags_enum { + CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS = + 0x1 /**< Automatically enable peer access between remote devices as needed + */ +} CUipcMem_flags; + +/** + * CUDA Mem Attach Flags + */ +typedef enum CUmemAttach_flags_enum { + CU_MEM_ATTACH_GLOBAL = + 0x1, /**< Memory can be accessed by any stream on any device */ + CU_MEM_ATTACH_HOST = + 0x2, /**< Memory cannot be accessed by any stream on any device */ + CU_MEM_ATTACH_SINGLE = 0x4 /**< Memory can only be accessed by a single stream + on the associated device */ +} CUmemAttach_flags; + +/** + * Context creation flags + */ +typedef enum CUctx_flags_enum { + CU_CTX_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + CU_CTX_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + CU_CTX_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + CU_CTX_SCHED_BLOCKING_SYNC = + 0x04, /**< Set blocking synchronization as default scheduling */ + CU_CTX_BLOCKING_SYNC = + 0x04, /**< Set blocking synchronization as default scheduling + * \deprecated This flag was deprecated as of CUDA 4.0 + * and was replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. */ + CU_CTX_SCHED_MASK = 0x07, + CU_CTX_MAP_HOST = + 0x08, /**< \deprecated This flag was deprecated as of CUDA 11.0 + * and it no longer has any effect. All contexts + * as of CUDA 3.2 behave as though the flag is enabled. 
*/ + CU_CTX_LMEM_RESIZE_TO_MAX = + 0x10, /**< Keep local memory allocation after launch */ + CU_CTX_COREDUMP_ENABLE = + 0x20, /**< Trigger coredumps from exceptions in this context */ + CU_CTX_USER_COREDUMP_ENABLE = + 0x40, /**< Enable user pipe to trigger coredumps in this context */ + CU_CTX_SYNC_MEMOPS = 0x80, /**< Ensure synchronous memory operations on this + context will synchronize */ + CU_CTX_FLAGS_MASK = 0xFF +} CUctx_flags; + +/** + * Event sched flags + */ +typedef enum CUevent_sched_flags_enum { + CU_EVENT_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + CU_EVENT_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + CU_EVENT_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + CU_EVENT_SCHED_BLOCKING_SYNC = + 0x04, /**< Set blocking synchronization as default scheduling */ +} CUevent_sched_flags; + +/** + * NVCL event scheduling flags + */ +typedef enum cl_event_flags_enum { + NVCL_EVENT_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + NVCL_EVENT_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + NVCL_EVENT_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + NVCL_EVENT_SCHED_BLOCKING_SYNC = + 0x04, /**< Set blocking synchronization as default scheduling */ +} cl_event_flags; + +/** + * NVCL context scheduling flags + */ +typedef enum cl_context_flags_enum { + NVCL_CTX_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + NVCL_CTX_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + NVCL_CTX_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + NVCL_CTX_SCHED_BLOCKING_SYNC = + 0x04, /**< Set blocking synchronization as default scheduling */ +} cl_context_flags; + +/** + * Stream creation flags + */ +typedef enum CUstream_flags_enum { + CU_STREAM_DEFAULT = 0x0, /**< Default stream flag */ + CU_STREAM_NON_BLOCKING = + 0x1 /**< Stream does not synchronize with stream 0 (the NULL stream) */ +} CUstream_flags; + +/** + * Legacy stream handle + * + * Stream handle that can be passed as a CUstream to use an implicit stream + * with legacy synchronization behavior. + * + * See details of the \link_sync_behavior + */ +#define CU_STREAM_LEGACY ((CUstream)0x1) + +/** + * Per-thread stream handle + * + * Stream handle that can be passed as a CUstream to use an implicit stream + * with per-thread synchronization behavior. + * + * See details of the \link_sync_behavior + */ +#define CU_STREAM_PER_THREAD ((CUstream)0x2) + +/** + * Event creation flags + */ +typedef enum CUevent_flags_enum { + CU_EVENT_DEFAULT = 0x0, /**< Default event flag */ + CU_EVENT_BLOCKING_SYNC = 0x1, /**< Event uses blocking synchronization */ + CU_EVENT_DISABLE_TIMING = 0x2, /**< Event will not record timing data */ + CU_EVENT_INTERPROCESS = 0x4 /**< Event is suitable for interprocess use. + CU_EVENT_DISABLE_TIMING must be set */ +} CUevent_flags; + +/** + * Event record flags + */ +typedef enum CUevent_record_flags_enum { + CU_EVENT_RECORD_DEFAULT = 0x0, /**< Default event record flag */ + CU_EVENT_RECORD_EXTERNAL = + 0x1 /**< When using stream capture, create an event record node + * instead of the default behavior. This flag is invalid + * when used outside of capture. */ +} CUevent_record_flags; + +/** + * Event wait flags + */ +typedef enum CUevent_wait_flags_enum { + CU_EVENT_WAIT_DEFAULT = 0x0, /**< Default event wait flag */ + CU_EVENT_WAIT_EXTERNAL = + 0x1 /**< When using stream capture, create an event wait node + * instead of the default behavior. 
This flag is invalid + * when used outside of capture.*/ +} CUevent_wait_flags; + +/** + * Flags for ::cuStreamWaitValue32 and ::cuStreamWaitValue64 + */ +typedef enum CUstreamWaitValue_flags_enum { + CU_STREAM_WAIT_VALUE_GEQ = + 0x0, /**< Wait until (int32_t)(*addr - value) >= 0 (or int64_t for 64 bit + values). Note this is a cyclic comparison which ignores + wraparound. (Default behavior.) */ + CU_STREAM_WAIT_VALUE_EQ = 0x1, /**< Wait until *addr == value. */ + CU_STREAM_WAIT_VALUE_AND = 0x2, /**< Wait until (*addr & value) != 0. */ + CU_STREAM_WAIT_VALUE_NOR = + 0x3, /**< Wait until ~(*addr | value) != 0. Support for this operation can + be queried with ::cuDeviceGetAttribute() and + ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR.*/ + CU_STREAM_WAIT_VALUE_FLUSH = + 1 << 30 /**< Follow the wait operation with a flush of outstanding remote + writes. This means that, if a remote write operation is + guaranteed to have reached the device before the wait can be + satisfied, that write is guaranteed to be visible to downstream + device work. The device is permitted to reorder remote writes + internally. For example, this flag would be required if two + remote writes arrive in a defined order, the wait is satisfied + by the second write, and downstream work needs to observe the + first write. Support for this operation is restricted to + selected platforms and can be queried with + ::CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES.*/ +} CUstreamWaitValue_flags; + +/** + * Flags for ::cuStreamWriteValue32 + */ +typedef enum CUstreamWriteValue_flags_enum { + CU_STREAM_WRITE_VALUE_DEFAULT = 0x0, /**< Default behavior */ + CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER = + 0x1 /**< Permits the write to be reordered with writes which were issued + before it, as a performance optimization. Normally, + ::cuStreamWriteValue32 will provide a memory fence before the + write, which has similar semantics to + __threadfence_system() but is scoped to the stream + rather than a CUDA thread. + This flag is not supported in the v2 API. */ +} CUstreamWriteValue_flags; + +/** + * Operations for ::cuStreamBatchMemOp + */ +typedef enum CUstreamBatchMemOpType_enum { + CU_STREAM_MEM_OP_WAIT_VALUE_32 = + 1, /**< Represents a ::cuStreamWaitValue32 operation */ + CU_STREAM_MEM_OP_WRITE_VALUE_32 = + 2, /**< Represents a ::cuStreamWriteValue32 operation */ + CU_STREAM_MEM_OP_WAIT_VALUE_64 = + 4, /**< Represents a ::cuStreamWaitValue64 operation */ + CU_STREAM_MEM_OP_WRITE_VALUE_64 = + 5, /**< Represents a ::cuStreamWriteValue64 operation */ + CU_STREAM_MEM_OP_BARRIER = + 6, /**< Insert a memory barrier of the specified type */ + CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES = + 3 /**< This has the same effect as ::CU_STREAM_WAIT_VALUE_FLUSH, but as a + standalone operation. */ +} CUstreamBatchMemOpType; + +/** + * Flags for ::cuStreamMemoryBarrier + */ +typedef enum CUstreamMemoryBarrier_flags_enum { + CU_STREAM_MEMORY_BARRIER_TYPE_SYS = 0x0, /**< System-wide memory barrier. */ + CU_STREAM_MEMORY_BARRIER_TYPE_GPU = + 0x1 /**< Limit memory barrier scope to the GPU. */ +} CUstreamMemoryBarrier_flags; + +/** + * Per-operation parameters for ::cuStreamBatchMemOp + */ +typedef union CUstreamBatchMemOpParams_union { + CUstreamBatchMemOpType operation; + struct CUstreamMemOpWaitValueParams_st { + CUstreamBatchMemOpType operation; + CUdeviceptr address; + union { + cuuint32_t value; + cuuint64_t value64; + }; + unsigned int flags; + CUdeviceptr + alias; /**< For driver internal use. Initial value is unimportant. 
*/ + } waitValue; + struct CUstreamMemOpWriteValueParams_st { + CUstreamBatchMemOpType operation; + CUdeviceptr address; + union { + cuuint32_t value; + cuuint64_t value64; + }; + unsigned int flags; + CUdeviceptr + alias; /**< For driver internal use. Initial value is unimportant. */ + } writeValue; + struct CUstreamMemOpFlushRemoteWritesParams_st { + CUstreamBatchMemOpType operation; + unsigned int flags; + } flushRemoteWrites; + struct CUstreamMemOpMemoryBarrierParams_st { /**< Only supported in the _v2 + API */ + CUstreamBatchMemOpType operation; + unsigned int flags; + } memoryBarrier; + cuuint64_t pad[6]; +} CUstreamBatchMemOpParams_v1; +typedef CUstreamBatchMemOpParams_v1 CUstreamBatchMemOpParams; + +typedef struct CUDA_BATCH_MEM_OP_NODE_PARAMS_v1_st { + CUcontext ctx; + unsigned int count; + CUstreamBatchMemOpParams *paramArray; + unsigned int flags; +} CUDA_BATCH_MEM_OP_NODE_PARAMS_v1; +typedef CUDA_BATCH_MEM_OP_NODE_PARAMS_v1 CUDA_BATCH_MEM_OP_NODE_PARAMS; + +/** + * Batch memory operation node parameters + */ +typedef struct CUDA_BATCH_MEM_OP_NODE_PARAMS_v2_st { + CUcontext ctx; /**< Context to use for the operations. */ + unsigned int count; /**< Number of operations in paramArray. */ + CUstreamBatchMemOpParams + *paramArray; /**< Array of batch memory operations. */ + unsigned int flags; /**< Flags to control the node. */ +} CUDA_BATCH_MEM_OP_NODE_PARAMS_v2; + +/** + * Occupancy calculator flag + */ +typedef enum CUoccupancy_flags_enum { + CU_OCCUPANCY_DEFAULT = 0x0, /**< Default behavior */ + CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE = + 0x1 /**< Assume global caching is enabled and cannot be automatically + turned off */ +} CUoccupancy_flags; + +/** + * Flags for ::cuStreamUpdateCaptureDependencies + */ +typedef enum CUstreamUpdateCaptureDependencies_flags_enum { + CU_STREAM_ADD_CAPTURE_DEPENDENCIES = + 0x0, /**< Add new nodes to the dependency set */ + CU_STREAM_SET_CAPTURE_DEPENDENCIES = + 0x1 /**< Replace the dependency set with the new nodes */ +} CUstreamUpdateCaptureDependencies_flags; + +/** + * Types of async notification that can be sent + */ +typedef enum CUasyncNotificationType_enum { + CU_ASYNC_NOTIFICATION_TYPE_OVER_BUDGET = 0x1 +} CUasyncNotificationType; + +/** + * Information passed to the user via the async notification callback + */ +typedef struct CUasyncNotificationInfo_st { + CUasyncNotificationType type; + union { + struct { + unsigned long long bytesOverBudget; + } overBudget; + } info; +} CUasyncNotificationInfo; + +/** + * CUDA async notification callback + * \param info Information describing what actions to take as a result of this + * trim notification. \param userData Pointer to user defined data provided at + * registration. \param callback The callback handle associated with this + * specific callback. 
+ */ +typedef void (*CUasyncCallback)(CUasyncNotificationInfo *info, void *userData, + CUasyncCallbackHandle callback); + +/** + * Array formats + */ +typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, /**< Unsigned 8-bit integers */ + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, /**< Unsigned 16-bit integers */ + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, /**< Unsigned 32-bit integers */ + CU_AD_FORMAT_SIGNED_INT8 = 0x08, /**< Signed 8-bit integers */ + CU_AD_FORMAT_SIGNED_INT16 = 0x09, /**< Signed 16-bit integers */ + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */ + CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */ + CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */ + CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */ + CU_AD_FORMAT_UNORM_INT8X1 = + 0xc0, /**< 1 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X2 = + 0xc1, /**< 2 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X4 = + 0xc2, /**< 4 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X1 = + 0xc3, /**< 1 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X2 = + 0xc4, /**< 2 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X4 = + 0xc5, /**< 4 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X1 = + 0xc6, /**< 1 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X2 = + 0xc7, /**< 2 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X4 = + 0xc8, /**< 4 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X1 = + 0xc9, /**< 1 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X2 = + 0xca, /**< 2 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X4 = + 0xcb, /**< 4 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized + block-compressed (BC1 compression) format */ + CU_AD_FORMAT_BC1_UNORM_SRGB = + 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 + compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized + block-compressed (BC2 compression) format */ + CU_AD_FORMAT_BC2_UNORM_SRGB = + 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 + compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized + block-compressed (BC3 compression) format */ + CU_AD_FORMAT_BC3_UNORM_SRGB = + 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 + compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized + block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized + block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized + block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized + block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC6H_UF16 = + 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H + compression) format */ + CU_AD_FORMAT_BC6H_SF16 = + 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) + format */ + CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized + block-compressed (BC7 compression) format */ + CU_AD_FORMAT_BC7_UNORM_SRGB = + 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) + 
format with sRGB encoding */ +} CUarray_format; + +/** + * Texture reference addressing modes + */ +typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, /**< Wrapping address mode */ + CU_TR_ADDRESS_MODE_CLAMP = 1, /**< Clamp to edge address mode */ + CU_TR_ADDRESS_MODE_MIRROR = 2, /**< Mirror address mode */ + CU_TR_ADDRESS_MODE_BORDER = 3 /**< Border address mode */ +} CUaddress_mode; + +/** + * Texture reference filtering modes + */ +typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, /**< Point filter mode */ + CU_TR_FILTER_MODE_LINEAR = 1 /**< Linear filter mode */ +} CUfilter_mode; + +/** + * Device properties + */ +typedef enum CUdevice_attribute_enum { + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = + 1, /**< Maximum number of threads per block */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2, /**< Maximum block dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y = 3, /**< Maximum block dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z = 4, /**< Maximum block dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X = 5, /**< Maximum grid dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y = 6, /**< Maximum grid dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z = 7, /**< Maximum grid dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = + 8, /**< Maximum shared memory available per block in bytes */ + CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK = + 8, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY = + 9, /**< Memory available on device for __constant__ variables in a CUDA C + kernel in bytes */ + CU_DEVICE_ATTRIBUTE_WARP_SIZE = 10, /**< Warp size in threads */ + CU_DEVICE_ATTRIBUTE_MAX_PITCH = + 11, /**< Maximum pitch in bytes allowed by memory copies */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = + 12, /**< Maximum number of 32-bit registers available per block */ + CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK = + 12, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_CLOCK_RATE = + 13, /**< Typical clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT = + 14, /**< Alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_GPU_OVERLAP = + 15, /**< Device can possibly copy memory and execute a kernel + concurrently. Deprecated. Use instead + CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT. 
*/ + CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = + 16, /**< Number of multiprocessors on device */ + CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT = + 17, /**< Specifies whether there is a run time limit on kernels */ + CU_DEVICE_ATTRIBUTE_INTEGRATED = + 18, /**< Device is integrated with host memory */ + CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY = + 19, /**< Device can map host memory into CUDA address space */ + CU_DEVICE_ATTRIBUTE_COMPUTE_MODE = + 20, /**< Compute mode (See ::CUcomputemode for details) */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH = + 21, /**< Maximum 1D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH = + 22, /**< Maximum 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT = + 23, /**< Maximum 2D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH = + 24, /**< Maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT = + 25, /**< Maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH = + 26, /**< Maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH = + 27, /**< Maximum 2D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT = + 28, /**< Maximum 2D layered texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS = + 29, /**< Maximum layers in a 2D layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH = + 27, /**< Deprecated, use + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT = + 28, /**< Deprecated, use + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES = + 29, /**< Deprecated, use + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS */ + CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT = + 30, /**< Alignment requirement for surfaces */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS = + 31, /**< Device can possibly execute multiple kernels concurrently */ + CU_DEVICE_ATTRIBUTE_ECC_ENABLED = 32, /**< Device has ECC support enabled */ + CU_DEVICE_ATTRIBUTE_PCI_BUS_ID = 33, /**< PCI bus ID of the device */ + CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID = 34, /**< PCI device ID of the device */ + CU_DEVICE_ATTRIBUTE_TCC_DRIVER = 35, /**< Device is using TCC driver model */ + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = + 36, /**< Peak memory clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = + 37, /**< Global memory bus width in bits */ + CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38, /**< Size of L2 cache in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = + 39, /**< Maximum resident threads per multiprocessor */ + CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT = + 40, /**< Number of asynchronous engines */ + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = + 41, /**< Device shares a unified address space with the host */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH = + 42, /**< Maximum 1D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS = + 43, /**< Maximum layers in a 1D layered texture */ + CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER = 44, /**< Deprecated, do not use. 
*/ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH = + 45, /**< Maximum 2D texture width if CUDA_ARRAY3D_TEXTURE_GATHER is set */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT = + 46, /**< Maximum 2D texture height if CUDA_ARRAY3D_TEXTURE_GATHER is set + */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE = + 47, /**< Alternate maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE = + 48, /**< Alternate maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE = + 49, /**< Alternate maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID = 50, /**< PCI domain ID of the device */ + CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT = + 51, /**< Pitch alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH = + 52, /**< Maximum cubemap texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH = + 53, /**< Maximum cubemap layered texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS = + 54, /**< Maximum layers in a cubemap layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH = + 55, /**< Maximum 1D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH = + 56, /**< Maximum 2D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT = + 57, /**< Maximum 2D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH = + 58, /**< Maximum 3D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT = + 59, /**< Maximum 3D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH = + 60, /**< Maximum 3D surface depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH = + 61, /**< Maximum 1D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS = + 62, /**< Maximum layers in a 1D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH = + 63, /**< Maximum 2D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT = + 64, /**< Maximum 2D layered surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS = + 65, /**< Maximum layers in a 2D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH = + 66, /**< Maximum cubemap surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH = + 67, /**< Maximum cubemap layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS = + 68, /**< Maximum layers in a cubemap layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH = + 69, /**< Deprecated, do not use. Use + cudaDeviceGetTexture1DLinearMaxWidth() or + cuDeviceGetTexture1DLinearMaxWidth() instead. 
*/ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH = + 70, /**< Maximum 2D linear texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT = + 71, /**< Maximum 2D linear texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH = + 72, /**< Maximum 2D linear texture pitch in bytes */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH = + 73, /**< Maximum mipmapped 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT = + 74, /**< Maximum mipmapped 2D texture height */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = + 75, /**< Major compute capability version number */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = + 76, /**< Minor compute capability version number */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH = + 77, /**< Maximum mipmapped 1D texture width */ + CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED = + 78, /**< Device supports stream priorities */ + CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED = + 79, /**< Device supports caching globals in L1 */ + CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED = + 80, /**< Device supports caching locals in L1 */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = + 81, /**< Maximum shared memory available per multiprocessor in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR = + 82, /**< Maximum number of 32-bit registers available per multiprocessor + */ + CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY = + 83, /**< Device can allocate managed memory on this system */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD = + 84, /**< Device is on a multi-GPU board */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID = + 85, /**< Unique id for a group of devices on the same multi-GPU board */ + CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED = + 86, /**< Link between the device and the host supports native atomic + operations (this is a placeholder attribute, and is not supported + on any current hardware)*/ + CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO = + 87, /**< Ratio of single precision performance (in floating-point + operations per second) to double precision performance */ + CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS = + 88, /**< Device supports coherently accessing pageable memory without + calling cudaHostRegister on it */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = + 89, /**< Device can coherently access managed memory concurrently with the + CPU */ + CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = + 90, /**< Device supports compute preemption. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = + 91, /**< Device can access host registered memory at the same virtual + address as the CPU */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS_V1 = + 92, /**< Deprecated, along with v1 MemOps API, ::cuStreamBatchMemOp and + related APIs are supported. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS_V1 = + 93, /**< Deprecated, along with v1 MemOps API, 64-bit operations are + supported in ::cuStreamBatchMemOp and related APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V1 = + 94, /**< Deprecated, along with v1 MemOps API, ::CU_STREAM_WAIT_VALUE_NOR + is supported. */ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = + 95, /**< Device supports launching cooperative kernels via + ::cuLaunchCooperativeKernel */ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH = + 96, /**< Deprecated, ::cuLaunchCooperativeKernelMultiDevice is deprecated. 
+ */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = + 97, /**< Maximum opt-in shared memory per block */ + CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES = + 98, /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the + ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the + device. See \ref CUDA_MEMOP for additional details. */ + CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED = + 99, /**< Device supports host memory registration via ::cudaHostRegister. + */ + CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES = + 100, /**< Device accesses pageable memory via the host's page tables. */ + CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST = + 101, /**< The host can directly access managed memory on the device + without migration. */ + CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED = + 102, /**< Deprecated, Use + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED*/ + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = + 102, /**< Device supports virtual memory management APIs like + ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs + */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = + 103, /**< Device supports exporting memory to a posix file descriptor with + ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = + 104, /**< Device supports exporting memory to a Win32 NT handle with + ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = + 105, /**< Device supports exporting memory to a Win32 KMT handle with + ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = + 106, /**< Maximum number of blocks per multiprocessor */ + CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = + 107, /**< Device supports compression of memory */ + CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = + 108, /**< Maximum L2 persisting lines capacity setting in bytes. */ + CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE = + 109, /**< Maximum value of CUaccessPolicyWindow::num_bytes. 
*/ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = + 110, /**< Device supports specifying the GPUDirect RDMA flag with + ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = + 111, /**< Shared memory reserved by CUDA driver per block in bytes */ + CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = + 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped + arrays */ + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = + 113, /**< Device supports using the ::cuMemHostRegister flag + ::CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be + mapped as read-only to the GPU */ + CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = + 114, /**< External timeline semaphore interop is supported on the device + */ + CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = + 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool + family of APIs */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = + 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages + (see https://docs.nvidia.com/cuda/gpudirect-rdma for more + information) */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS = + 117, /**< The returned attribute shall be interpreted as a bitmask, where + the individual bits are described by the + ::CUflushGPUDirectRDMAWritesOptions enum */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING = + 118, /**< GPUDirect RDMA writes to the device do not need to be flushed + for consumers within the scope indicated by the returned + attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical + values returned here. */ + CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES = + 119, /**< Handle types supported with mempool based IPC */ + CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = + 120, /**< Indicates device supports cluster launch */ + CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED = + 121, /**< Device supports deferred mapping CUDA arrays and CUDA mipmapped + arrays */ + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = + 122, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and + related MemOp APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = + 123, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported by MemOp APIs. */ + CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED = + 124, /**< Device supports buffer sharing with dma_buf mechanism. */ + CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED = + 125, /**< Device supports IPC Events. */ + CU_DEVICE_ATTRIBUTE_MEM_SYNC_DOMAIN_COUNT = + 126, /**< Number of memory domains the device supports. */ + CU_DEVICE_ATTRIBUTE_TENSOR_MAP_ACCESS_SUPPORTED = + 127, /**< Device supports accessing memory using Tensor Map. */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED = + 128, /**< Device supports exporting memory to a fabric handle with + cuMemExportToShareableHandle() or requested with cuMemCreate() */ + CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS = + 129, /**< Device supports unified function pointers. */ + CU_DEVICE_ATTRIBUTE_NUMA_CONFIG = 130, + CU_DEVICE_ATTRIBUTE_NUMA_ID = 131, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED = + 132, /**< Device supports switch multicast and reduction operations. */ + CU_DEVICE_ATTRIBUTE_MPS_ENABLED = + 133, /**< Indicates if contexts created on this device will be shared via + MPS */ + CU_DEVICE_ATTRIBUTE_HOST_NUMA_ID = + 134, /**< NUMA ID of the host node closest to the device. Returns -1 when + system does not support NUMA. 
*/ + CU_DEVICE_ATTRIBUTE_MAX +} CUdevice_attribute; + +/** + * Legacy device properties + */ +typedef struct CUdevprop_st { + int maxThreadsPerBlock; /**< Maximum number of threads per block */ + int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ + int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ + int sharedMemPerBlock; /**< Shared memory available per block in bytes */ + int totalConstantMemory; /**< Constant memory available on device in bytes */ + int SIMDWidth; /**< Warp size in threads */ + int memPitch; /**< Maximum pitch in bytes allowed by memory copies */ + int regsPerBlock; /**< 32-bit registers available per block */ + int clockRate; /**< Clock frequency in kilohertz */ + int textureAlign; /**< Alignment requirement for textures */ +} CUdevprop_v1; +typedef CUdevprop_v1 CUdevprop; + +/** + * Pointer information + */ +typedef enum CUpointer_attribute_enum { + CU_POINTER_ATTRIBUTE_CONTEXT = + 1, /**< The ::CUcontext on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2, /**< The ::CUmemorytype describing the + physical location of a pointer */ + CU_POINTER_ATTRIBUTE_DEVICE_POINTER = + 3, /**< The address at which a pointer's memory may be accessed on the + device */ + CU_POINTER_ATTRIBUTE_HOST_POINTER = + 4, /**< The address at which a pointer's memory may be accessed on the + host */ + CU_POINTER_ATTRIBUTE_P2P_TOKENS = 5, /**< A pair of tokens for use with the + nv-p2p.h Linux kernel interface */ + CU_POINTER_ATTRIBUTE_SYNC_MEMOPS = + 6, /**< Synchronize every synchronous memory operation initiated on this + region */ + CU_POINTER_ATTRIBUTE_BUFFER_ID = + 7, /**< A process-wide unique ID for an allocated memory region*/ + CU_POINTER_ATTRIBUTE_IS_MANAGED = + 8, /**< Indicates if the pointer points to managed memory */ + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = + 9, /**< A device ordinal of a device on which a pointer was allocated or + registered */ + CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE = + 10, /**< 1 if this pointer maps to an allocation that is suitable for + ::cudaIpcGetMemHandle, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR = + 11, /**< Starting address for this requested pointer */ + CU_POINTER_ATTRIBUTE_RANGE_SIZE = + 12, /**< Size of the address range for this requested pointer */ + CU_POINTER_ATTRIBUTE_MAPPED = + 13, /**< 1 if this pointer is in a valid address range that is mapped to a + backing allocation, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES = + 14, /**< Bitmask of allowed ::CUmemAllocationHandleType for this + allocation **/ + CU_POINTER_ATTRIBUTE_IS_GPU_DIRECT_RDMA_CAPABLE = + 15, /**< 1 if the memory this pointer is referencing can be used with the + GPUDirect RDMA API **/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAGS = + 16, /**< Returns the access flags the device associated with the current + context has on the corresponding memory referenced by the pointer + given */ + CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE = + 17, /**< Returns the mempool handle for the allocation if it was allocated + from a mempool. Otherwise returns NULL. 
**/ + CU_POINTER_ATTRIBUTE_MAPPING_SIZE = + 18, /**< Size of the actual underlying mapping that the pointer belongs to + **/ + CU_POINTER_ATTRIBUTE_MAPPING_BASE_ADDR = + 19, /**< The start address of the mapping that the pointer belongs to **/ + CU_POINTER_ATTRIBUTE_MEMORY_BLOCK_ID = + 20 /**< A process-wide unique id corresponding to the physical allocation + the pointer belongs to **/ +} CUpointer_attribute; + +/** + * Function properties + */ +typedef enum CUfunction_attribute_enum { + /** + * The maximum number of threads per block, beyond which a launch of the + * function would fail. This number depends on both the function and the + * device on which the function is currently loaded. + */ + CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0, + + /** + * The size in bytes of statically-allocated shared memory required by + * this function. This does not include dynamically-allocated shared + * memory requested by the user at runtime. + */ + CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES = 1, + + /** + * The size in bytes of user-allocated constant memory required by this + * function. + */ + CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES = 2, + + /** + * The size in bytes of local memory used by each thread of this function. + */ + CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES = 3, + + /** + * The number of registers used by each thread of this function. + */ + CU_FUNC_ATTRIBUTE_NUM_REGS = 4, + + /** + * The PTX virtual architecture version for which the function was + * compiled. This value is the major PTX version * 10 + the minor PTX + * version, so a PTX version 1.3 function would return the value 13. + * Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. + */ + CU_FUNC_ATTRIBUTE_PTX_VERSION = 5, + + /** + * The binary architecture version for which the function was compiled. + * This value is the major binary version * 10 + the minor binary version, + * so a binary version 1.3 function would return the value 13. Note that + * this will return a value of 10 for legacy cubins that do not have a + * properly-encoded binary architecture version. + */ + CU_FUNC_ATTRIBUTE_BINARY_VERSION = 6, + + /** + * The attribute to indicate whether the function has been compiled with + * user specified option "-Xptxas --dlcm=ca" set . + */ + CU_FUNC_ATTRIBUTE_CACHE_MODE_CA = 7, + + /** + * The maximum size in bytes of dynamically-allocated shared memory that can + * be used by this function. If the user-specified dynamic shared memory size + * is larger than this value, the launch will fail. See ::cuFuncSetAttribute, + * ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8, + + /** + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets the shared memory carveout preference, in percent of + * the total shared memory. Refer to + * ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR. This is only a + * hint, and the driver can choose a different ratio if required to execute + * the function. See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT = 9, + + /** + * If this attribute is set, the kernel must launch with a valid cluster + * size specified. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET = 10, + + /** + * The required cluster width in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. 
+ * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH = 11, + + /** + * The required cluster height in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. + * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime should return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT = 12, + + /** + * The required cluster depth in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. + * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime should return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH = 13, + + /** + * Whether the function can be launched with non-portable cluster size. 1 is + * allowed, 0 is disallowed. A non-portable cluster size may only function + * on the specific SKUs the program is tested on. The launch might fail if + * the program is run on a different hardware platform. + * + * CUDA API provides cudaOccupancyMaxActiveClusters to assist with checking + * whether the desired size can be launched on the current device. + * + * Portable Cluster Size + * + * A portable cluster size is guaranteed to be functional on all compute + * capabilities higher than the target compute capability. The portable + * cluster size for sm_90 is 8 blocks per cluster. This value may increase + * for future compute capabilities. + * + * The specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED = 14, + + /** + * The block scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy / cudaClusterSchedulingPolicy. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = 15, + + CU_FUNC_ATTRIBUTE_MAX +} CUfunction_attribute; + +/** + * Function cache configurations + */ +typedef enum CUfunc_cache_enum { + CU_FUNC_CACHE_PREFER_NONE = + 0x00, /**< no preference for shared memory or L1 (default) */ + CU_FUNC_CACHE_PREFER_SHARED = + 0x01, /**< prefer larger shared memory and smaller L1 cache */ + CU_FUNC_CACHE_PREFER_L1 = + 0x02, /**< prefer larger L1 cache and smaller shared memory */ + CU_FUNC_CACHE_PREFER_EQUAL = + 0x03 /**< prefer equal sized L1 cache and shared memory */ +} CUfunc_cache; + +/** + * \deprecated + * + * Shared memory configurations + */ +typedef enum CUsharedconfig_enum { + CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE = + 0x00, /**< set default shared memory bank size */ + CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE = + 0x01, /**< set shared memory bank width to four bytes */ + CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE = + 0x02 /**< set shared memory bank width to eight bytes */ +} CUsharedconfig; + +/** + * Shared memory carveout configurations. 
These may be passed to + * ::cuFuncSetAttribute or ::cuKernelSetAttribute + */ +typedef enum CUshared_carveout_enum { + CU_SHAREDMEM_CARVEOUT_DEFAULT = + -1, /**< No preference for shared memory or L1 (default) */ + CU_SHAREDMEM_CARVEOUT_MAX_SHARED = + 100, /**< Prefer maximum available shared memory, minimum L1 cache */ + CU_SHAREDMEM_CARVEOUT_MAX_L1 = + 0 /**< Prefer maximum available L1 cache, minimum shared memory */ +} CUshared_carveout; + +/** + * Memory types + */ +typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, /**< Host memory */ + CU_MEMORYTYPE_DEVICE = 0x02, /**< Device memory */ + CU_MEMORYTYPE_ARRAY = 0x03, /**< Array memory */ + CU_MEMORYTYPE_UNIFIED = 0x04 /**< Unified device or host memory */ +} CUmemorytype; + +/** + * Compute Modes + */ +typedef enum CUcomputemode_enum { + CU_COMPUTEMODE_DEFAULT = + 0, /**< Default compute mode (Multiple contexts allowed per device) */ + CU_COMPUTEMODE_PROHIBITED = 2, /**< Compute-prohibited mode (No contexts can + be created on this device at this time) */ + CU_COMPUTEMODE_EXCLUSIVE_PROCESS = + 3 /**< Compute-exclusive-process mode (Only one context used by a single + process can be present on this device at a time) */ +} CUcomputemode; + +/** + * Memory advise values + */ +typedef enum CUmem_advise_enum { + CU_MEM_ADVISE_SET_READ_MOSTLY = + 1, /**< Data will mostly be read and only occasionally be written to */ + CU_MEM_ADVISE_UNSET_READ_MOSTLY = + 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */ + CU_MEM_ADVISE_SET_PREFERRED_LOCATION = + 3, /**< Set the preferred location for the data as the specified device */ + CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = + 4, /**< Clear the preferred location for the data */ + CU_MEM_ADVISE_SET_ACCESSED_BY = + 5, /**< Data will be accessed by the specified device, so prevent page + faults as much as possible */ + CU_MEM_ADVISE_UNSET_ACCESSED_BY = + 6 /**< Let the Unified Memory subsystem decide on the page faulting policy + for the specified device */ +} CUmem_advise; + +typedef enum CUmem_range_attribute_enum { + CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = + 1, /**< Whether the range will mostly be read and only occasionally be + written to */ + CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = + 2, /**< The preferred location of the range */ + CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = + 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified + device */ + CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = + 4 /**< The last location to which the range was prefetched */ + , + CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_TYPE = + 5 /**< The preferred location type of the range */ + , + CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_ID = + 6 /**< The preferred location id of the range */ + , + CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_TYPE = + 7 /**< The last location type to which the range was prefetched */ + , + CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_ID = + 8 /**< The last location id to which the range was prefetched */ +} CUmem_range_attribute; + +/** + * Online compiler and linker options + */ +typedef enum CUjit_option_enum { + /** + * Max number of registers that a thread may use.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_MAX_REGISTERS = 0, + + /** + * IN: Specifies minimum number of threads per block to target compilation + * for\n + * OUT: Returns the number of threads the compiler actually targeted. + * This restricts the resource utilization of the compiler (e.g. 
max + * registers) such that a block with the given number of threads should be + * able to launch based on register limitations. Note, this option does not + * currently take into account any other resource limitations, such as + * shared memory utilization.\n + * Cannot be combined with ::CU_JIT_TARGET.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_THREADS_PER_BLOCK = 1, + + /** + * Overwrites the option value with the total wall clock time, in + * milliseconds, spent in the compiler and linker\n + * Option type: float\n + * Applies to: compiler and linker + */ + CU_JIT_WALL_TIME = 2, + + /** + * Pointer to a buffer in which to print any log messages + * that are informational in nature (the buffer size is specified via + * option ::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES)\n + * Option type: char *\n + * Applies to: compiler and linker + */ + CU_JIT_INFO_LOG_BUFFER = 3, + + /** + * IN: Log buffer size in bytes. Log messages will be capped at this size + * (including null terminator)\n + * OUT: Amount of log buffer filled with messages\n + * Option type: unsigned int\n + * Applies to: compiler and linker + */ + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES = 4, + + /** + * Pointer to a buffer in which to print any log messages that + * reflect errors (the buffer size is specified via option + * ::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES)\n + * Option type: char *\n + * Applies to: compiler and linker + */ + CU_JIT_ERROR_LOG_BUFFER = 5, + + /** + * IN: Log buffer size in bytes. Log messages will be capped at this size + * (including null terminator)\n + * OUT: Amount of log buffer filled with messages\n + * Option type: unsigned int\n + * Applies to: compiler and linker + */ + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES = 6, + + /** + * Level of optimizations to apply to generated code (0 - 4), with 4 + * being the default and highest level of optimizations.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_OPTIMIZATION_LEVEL = 7, + + /** + * No option value required. Determines the target based on the current + * attached context (default)\n + * Option type: No option value needed\n + * Applies to: compiler and linker + */ + CU_JIT_TARGET_FROM_CUCONTEXT = 8, + + /** + * Target is chosen based on supplied ::CUjit_target. Cannot be + * combined with ::CU_JIT_THREADS_PER_BLOCK.\n + * Option type: unsigned int for enumerated type ::CUjit_target\n + * Applies to: compiler and linker + */ + CU_JIT_TARGET = 9, + + /** + * Specifies choice of fallback strategy if matching cubin is not found. + * Choice is based on supplied ::CUjit_fallback. 
This option cannot be + * used with cuLink* APIs as the linker requires exact matches.\n + * Option type: unsigned int for enumerated type ::CUjit_fallback\n + * Applies to: compiler only + */ + CU_JIT_FALLBACK_STRATEGY = 10, + + /** + * Specifies whether to create debug information in output (-g) + * (0: false, default)\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_GENERATE_DEBUG_INFO = 11, + + /** + * Generate verbose log messages (0: false, default)\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_LOG_VERBOSE = 12, + + /** + * Generate line number information (-lineinfo) (0: false, default)\n + * Option type: int\n + * Applies to: compiler only + */ + CU_JIT_GENERATE_LINE_INFO = 13, + + /** + * Specifies whether to enable caching explicitly (-dlcm) \n + * Choice is based on supplied ::CUjit_cacheMode_enum.\n + * Option type: unsigned int for enumerated type ::CUjit_cacheMode_enum\n + * Applies to: compiler only + */ + CU_JIT_CACHE_MODE = 14, + + /** + * \deprecated + * This jit option is deprecated and should not be used. + */ + CU_JIT_NEW_SM3X_OPT = 15, + + /** + * This jit option is used for internal purpose only. + */ + CU_JIT_FAST_COMPILE = 16, + + /** + * Array of device symbol names that will be relocated to the corresponding + * host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n + * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n + * When loading a device module, driver will relocate all encountered + * unresolved symbols to the host addresses.\n + * It is only allowed to register symbols that correspond to unresolved + * global variables.\n + * It is illegal to register the same device symbol at multiple addresses.\n + * Option type: const char **\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_NAMES = 17, + + /** + * Array of host addresses that will be used to relocate corresponding + * device symbols stored in ::CU_JIT_GLOBAL_SYMBOL_NAMES.\n + * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n + * Option type: void **\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_ADDRESSES = 18, + + /** + * Number of entries in ::CU_JIT_GLOBAL_SYMBOL_NAMES and + * ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES arrays.\n + * Option type: unsigned int\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_COUNT = 19, + + /** + * \deprecated + * Enable link-time optimization (-dlto) for device code (Disabled by + * default).\n This option is not supported on 32-bit platforms.\n Option + * type: int\n Applies to: compiler and linker + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_LTO = 20, + + /** + * \deprecated + * Control single-precision denormals (-ftz) support (0: false, default). + * 1 : flushes denormal values to zero + * 0 : preserves denormal values + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_FTZ = 21, + + /** + * \deprecated + * Control single-precision floating-point division and reciprocals + * (-prec-div) support (1: true, default). 
+ * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_PREC_DIV = 22, + + /** + * \deprecated + * Control single-precision floating-point square root + * (-prec-sqrt) support (1: true, default). + * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_PREC_SQRT = 23, + + /** + * \deprecated + * Enable/Disable the contraction of floating-point multiplies + * and adds/subtracts into floating-point multiply-add (-fma) + * operations (1: Enable, default; 0: Disable). + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_FMA = 24, + + /** + * \deprecated + * Array of kernel names that should be preserved at link time while others + * can be removed.\n + * Must contain ::CU_JIT_REFERENCED_KERNEL_COUNT entries.\n + * Note that kernel names can be mangled by the compiler in which case the + * mangled name needs to be specified.\n + * Wildcard "*" can be used to represent zero or more characters instead of + * specifying the full or mangled name.\n + * It is important to note that the wildcard "*" is also added implicitly. + * For example, specifying "foo" will match "foobaz", "barfoo", "barfoobaz" + * and thus preserve all kernels with those names. This can be avoided by + * providing a more specific name like "barfoobaz".\n Option type: const char + * **\n Applies to: dynamic linker only + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_KERNEL_NAMES = 25, + + /** + * \deprecated + * Number of entries in ::CU_JIT_REFERENCED_KERNEL_NAMES array.\n + * Option type: unsigned int\n + * Applies to: dynamic linker only + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_KERNEL_COUNT = 26, + + /** + * \deprecated + * Array of variable names (__device__ and/or __constant__) that should be + * preserved at link time while others can be removed.\n + * Must contain ::CU_JIT_REFERENCED_VARIABLE_COUNT entries.\n + * Note that variable names can be mangled by the compiler in which case the + * mangled name needs to be specified.\n + * Wildcard "*" can be used to represent zero or more characters instead of + * specifying the full or mangled name.\n + * It is important to note that the wildcard "*" is also added implicitly. + * For example, specifying "foo" will match "foobaz", "barfoo", "barfoobaz" + * and thus preserve all variables with those names. 
This can be avoided by + providing a more specific name like "barfoobaz".\n Option type: const char + * **\n Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_VARIABLE_NAMES = 27, + + /** + * \deprecated + * Number of entries in ::CU_JIT_REFERENCED_VARIABLE_NAMES array.\n + * Option type: unsigned int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_VARIABLE_COUNT = 28, + + /** + * \deprecated + * This option serves as a hint to enable the JIT compiler/linker + * to remove constant (__constant__) and device (__device__) variables + * unreferenced in device code (Disabled by default).\n + * Note that host references to constant and device variables using APIs like + * ::cuModuleGetGlobal() with this option specified may result in undefined + * behavior unless the variables are explicitly specified using + * ::CU_JIT_REFERENCED_VARIABLE_NAMES.\n Option type: int\n Applies to: + * link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES = 29, + + /** + * Generate position independent code (0: false)\n + * Option type: int\n + * Applies to: compiler only + */ + CU_JIT_POSITION_INDEPENDENT_CODE = 30, + + /** + * This option hints to the JIT compiler the minimum number of CTAs from the + * kernel’s grid to be mapped to an SM. This option is ignored when used + * together with ::CU_JIT_MAX_REGISTERS or ::CU_JIT_THREADS_PER_BLOCK. + * Optimizations based on this option need ::CU_JIT_MAX_THREADS_PER_BLOCK to + * be specified as well. For kernels already using PTX directive + * .minnctapersm, this option will be ignored by default. Use + * ::CU_JIT_OVERRIDE_DIRECTIVE_VALUES to let this option take precedence over + * the PTX directive. Option type: unsigned int\n Applies to: compiler only + */ + CU_JIT_MIN_CTA_PER_SM = 31, + + /** + * Maximum number of threads in a thread block, computed as the product of + * the maximum extent specific for each dimension of the block. This limit + * is guaranteed not to be exceeded in any invocation of the kernel. Exceeding + * the maximum number of threads results in a runtime error or kernel launch + * failure. For kernels already using PTX directive .maxntid, this option will + * be ignored by default. Use ::CU_JIT_OVERRIDE_DIRECTIVE_VALUES to let this + * option take precedence over the PTX directive. + * Option type: int\n + * Applies to: compiler only + */ + CU_JIT_MAX_THREADS_PER_BLOCK = 32, + + /** + * This option lets the values specified using ::CU_JIT_MAX_REGISTERS, + * ::CU_JIT_THREADS_PER_BLOCK, ::CU_JIT_MAX_THREADS_PER_BLOCK and + * ::CU_JIT_MIN_CTA_PER_SM take precedence over any PTX directives. + * (0: Disable, default; 1: Enable) + * Option type: int\n + * Applies to: compiler only + */ + CU_JIT_OVERRIDE_DIRECTIVE_VALUES = 33, + CU_JIT_NUM_OPTIONS + +} CUjit_option; + +/* + * Indicates that compute device class supports accelerated features. 
+ */ +#define CU_COMPUTE_ACCELERATED_TARGET_BASE 0x10000 + +/** + * Online compilation targets + */ +typedef enum CUjit_target_enum { + CU_TARGET_COMPUTE_30 = 30, /**< Compute device class 3.0 */ + CU_TARGET_COMPUTE_32 = 32, /**< Compute device class 3.2 */ + CU_TARGET_COMPUTE_35 = 35, /**< Compute device class 3.5 */ + CU_TARGET_COMPUTE_37 = 37, /**< Compute device class 3.7 */ + CU_TARGET_COMPUTE_50 = 50, /**< Compute device class 5.0 */ + CU_TARGET_COMPUTE_52 = 52, /**< Compute device class 5.2 */ + CU_TARGET_COMPUTE_53 = 53, /**< Compute device class 5.3 */ + CU_TARGET_COMPUTE_60 = 60, /**< Compute device class 6.0.*/ + CU_TARGET_COMPUTE_61 = 61, /**< Compute device class 6.1.*/ + CU_TARGET_COMPUTE_62 = 62, /**< Compute device class 6.2.*/ + CU_TARGET_COMPUTE_70 = 70, /**< Compute device class 7.0.*/ + CU_TARGET_COMPUTE_72 = 72, /**< Compute device class 7.2.*/ + CU_TARGET_COMPUTE_75 = 75, /**< Compute device class 7.5.*/ + CU_TARGET_COMPUTE_80 = 80, /**< Compute device class 8.0.*/ + CU_TARGET_COMPUTE_86 = 86, /**< Compute device class 8.6.*/ + CU_TARGET_COMPUTE_87 = 87, /**< Compute device class 8.7.*/ + CU_TARGET_COMPUTE_89 = 89, /**< Compute device class 8.9.*/ + CU_TARGET_COMPUTE_90 = 90, /**< Compute device class 9.0.*/ + + /**< Compute device class 9.0. with accelerated features.*/ + CU_TARGET_COMPUTE_90A = + CU_COMPUTE_ACCELERATED_TARGET_BASE + CU_TARGET_COMPUTE_90, +} CUjit_target; + +/** + * Cubin matching fallback strategies + */ +typedef enum CUjit_fallback_enum { + CU_PREFER_PTX = + 0, /**< Prefer to compile ptx if exact binary match not found */ + + CU_PREFER_BINARY /**< Prefer to fall back to compatible binary code if exact + match not found */ + +} CUjit_fallback; + +/** + * Caching modes for dlcm + */ +typedef enum CUjit_cacheMode_enum { + CU_JIT_CACHE_OPTION_NONE = 0, /**< Compile with no -dlcm flag specified */ + CU_JIT_CACHE_OPTION_CG, /**< Compile with L1 cache disabled */ + CU_JIT_CACHE_OPTION_CA /**< Compile with L1 cache enabled */ +} CUjit_cacheMode; + +/** + * Device code formats + */ +typedef enum CUjitInputType_enum { + /** + * Compiled device-class-specific device code\n + * Applicable options: none + */ + CU_JIT_INPUT_CUBIN = 0, + + /** + * PTX source code\n + * Applicable options: PTX compiler options + */ + CU_JIT_INPUT_PTX = 1, + + /** + * Bundle of multiple cubins and/or PTX of some device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_FATBINARY = 2, + + /** + * Host object with embedded device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_OBJECT = 3, + + /** + * Archive of host objects with embedded device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_LIBRARY = 4, + + /** + * \deprecated + * High-level intermediate code for link-time optimization\n + * Applicable options: NVVM compiler options, PTX compiler options + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_INPUT_NVVM = 5, + + CU_JIT_NUM_INPUT_TYPES = 6 +} CUjitInputType; + +typedef struct CUlinkState_st *CUlinkState; + +/** + * Flags to register a graphics resource + */ +typedef enum CUgraphicsRegisterFlags_enum { + CU_GRAPHICS_REGISTER_FLAGS_NONE = 0x00, + CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY = 0x01, + CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD = 0x02, + CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST = 0x04, + CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER = 0x08 +} CUgraphicsRegisterFlags; + +/** + * Flags for 
mapping and unmapping interop resources + */ +typedef enum CUgraphicsMapResourceFlags_enum { + CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE = 0x00, + CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY = 0x01, + CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD = 0x02 +} CUgraphicsMapResourceFlags; + +/** + * Array indices for cube faces + */ +typedef enum CUarray_cubemap_face_enum { + CU_CUBEMAP_FACE_POSITIVE_X = 0x00, /**< Positive X face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_X = 0x01, /**< Negative X face of cubemap */ + CU_CUBEMAP_FACE_POSITIVE_Y = 0x02, /**< Positive Y face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_Y = 0x03, /**< Negative Y face of cubemap */ + CU_CUBEMAP_FACE_POSITIVE_Z = 0x04, /**< Positive Z face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_Z = 0x05 /**< Negative Z face of cubemap */ +} CUarray_cubemap_face; + +/** + * Limits + */ +typedef enum CUlimit_enum { + CU_LIMIT_STACK_SIZE = 0x00, /**< GPU thread stack size */ + CU_LIMIT_PRINTF_FIFO_SIZE = 0x01, /**< GPU printf FIFO size */ + CU_LIMIT_MALLOC_HEAP_SIZE = 0x02, /**< GPU malloc heap size */ + CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH = + 0x03, /**< GPU device runtime launch synchronize depth */ + CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT = + 0x04, /**< GPU device runtime pending launch count */ + CU_LIMIT_MAX_L2_FETCH_GRANULARITY = + 0x05, /**< A value between 0 and 128 that indicates the maximum fetch + granularity of L2 (in Bytes). This is a hint */ + CU_LIMIT_PERSISTING_L2_CACHE_SIZE = + 0x06, /**< A size in bytes for L2 persisting lines cache size */ + CU_LIMIT_MAX +} CUlimit; + +/** + * Resource types + */ +typedef enum CUresourcetype_enum { + CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resource */ + CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */ + CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */ + CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */ +} CUresourcetype; + +#ifdef _WIN32 +#define CUDA_CB __stdcall +#else +#define CUDA_CB +#endif + +/** + * CUDA host function + * \param userData Argument value passed to the function + */ +typedef void(CUDA_CB *CUhostFn)(void *userData); + +/** + * Specifies performance hint with ::CUaccessPolicyWindow for hitProp and + * missProp members. + */ +typedef enum CUaccessProperty_enum { + CU_ACCESS_PROPERTY_NORMAL = 0, /**< Normal cache persistence. */ + CU_ACCESS_PROPERTY_STREAMING = + 1, /**< Streaming access is less likely to persist in cache. */ + CU_ACCESS_PROPERTY_PERSISTING = + 2 /**< Persisting access is more likely to persist in cache.*/ +} CUaccessProperty; + +/** + * Specifies an access policy for a window, a contiguous extent of memory + * beginning at base_ptr and ending at base_ptr + num_bytes. + * num_bytes is limited by CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE. + * Partition into many segments and assign segments such that: + * sum of "hit segments" / window == approx. ratio. + * sum of "miss segments" / window == approx 1-ratio. + * Segments and ratio specifications are fitted to the capabilities of + * the architecture. + * Accesses in a hit segment apply the hitProp access policy. + * Accesses in a miss segment apply the missProp access policy. + */ +typedef struct CUaccessPolicyWindow_st { + void *base_ptr; /**< Starting address of the access policy window. CUDA driver + may align it. */ + size_t num_bytes; /**< Size in bytes of the window policy. CUDA driver may + restrict the maximum size and alignment. */ + float hitRatio; /**< hitRatio specifies percentage of lines assigned hitProp, + rest are assigned missProp. 
*/ + CUaccessProperty hitProp; /**< ::CUaccessProperty set for hit. */ + CUaccessProperty missProp; /**< ::CUaccessProperty set for miss. Must be + either NORMAL or STREAMING */ +} CUaccessPolicyWindow_v1; +/** + * Access policy window + */ +typedef CUaccessPolicyWindow_v1 CUaccessPolicyWindow; + +/** + * GPU kernel node parameters + */ +typedef struct CUDA_KERNEL_NODE_PARAMS_st { + CUfunction func; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block + in bytes */ + void **kernelParams; /**< Array of pointers to kernel parameters */ + void **extra; /**< Extra options */ +} CUDA_KERNEL_NODE_PARAMS_v1; + +/** + * GPU kernel node parameters + */ +typedef struct CUDA_KERNEL_NODE_PARAMS_v2_st { + CUfunction func; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block + in bytes */ + void **kernelParams; /**< Array of pointers to kernel parameters */ + void **extra; /**< Extra options */ + CUkernel + kern; /**< Kernel to launch, will only be referenced if func is NULL */ + CUcontext ctx; /**< Context for the kernel task to run in. The value NULL will + indicate the current context should be used by the api. This + field is ignored if func is set. */ +} CUDA_KERNEL_NODE_PARAMS_v2; +typedef CUDA_KERNEL_NODE_PARAMS_v2 CUDA_KERNEL_NODE_PARAMS; + +/** + * GPU kernel node parameters + */ +typedef struct CUDA_KERNEL_NODE_PARAMS_v3_st { + CUfunction func; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block + in bytes */ + void **kernelParams; /**< Array of pointers to kernel parameters */ + void **extra; /**< Extra options */ + CUkernel + kern; /**< Kernel to launch, will only be referenced if func is NULL */ + CUcontext ctx; /**< Context for the kernel task to run in. The value NULL will + indicate the current context should be used by the api. This + field is ignored if func is set. */ +} CUDA_KERNEL_NODE_PARAMS_v3; + +/** + * Memset node parameters + */ +typedef struct CUDA_MEMSET_NODE_PARAMS_st { + CUdeviceptr dst; /**< Destination device pointer */ + size_t + pitch; /**< Pitch of destination device pointer. Unused if height is 1 */ + unsigned int value; /**< Value to be set */ + unsigned int + elementSize; /**< Size of each element in bytes. Must be 1, 2, or 4. 
*/ + size_t width; /**< Width of the row in elements */ + size_t height; /**< Number of rows */ +} CUDA_MEMSET_NODE_PARAMS_v1; +typedef CUDA_MEMSET_NODE_PARAMS_v1 CUDA_MEMSET_NODE_PARAMS; + +/** + * Memset node parameters + */ +typedef struct CUDA_MEMSET_NODE_PARAMS_v2_st { + CUdeviceptr dst; /**< Destination device pointer */ + size_t + pitch; /**< Pitch of destination device pointer. Unused if height is 1 */ + unsigned int value; /**< Value to be set */ + unsigned int + elementSize; /**< Size of each element in bytes. Must be 1, 2, or 4. */ + size_t width; /**< Width of the row in elements */ + size_t height; /**< Number of rows */ + CUcontext ctx; /**< Context on which to run the node */ +} CUDA_MEMSET_NODE_PARAMS_v2; + +/** + * Host node parameters + */ +typedef struct CUDA_HOST_NODE_PARAMS_st { + CUhostFn fn; /**< The function to call when the node executes */ + void *userData; /**< Argument to pass to the function */ +} CUDA_HOST_NODE_PARAMS_v1; +typedef CUDA_HOST_NODE_PARAMS_v1 CUDA_HOST_NODE_PARAMS; + +/** + * Host node parameters + */ +typedef struct CUDA_HOST_NODE_PARAMS_v2_st { + CUhostFn fn; /**< The function to call when the node executes */ + void *userData; /**< Argument to pass to the function */ +} CUDA_HOST_NODE_PARAMS_v2; + +/** + * Conditional node handle flags + */ +#define CU_GRAPH_COND_ASSIGN_DEFAULT \ + 0x1 /**< Default value is applied when graph is launched. */ + +/** + * Conditional node types + */ +typedef enum CUgraphConditionalNodeType_enum { + CU_GRAPH_COND_TYPE_IF = 0, /**< Conditional 'if' Node. Body executed once if + condition value is non-zero. */ + CU_GRAPH_COND_TYPE_WHILE = + 1, /**< Conditional 'while' Node. Body executed repeatedly while condition + value is non-zero. */ +} CUgraphConditionalNodeType; + +/** + * Conditional node parameters + */ +typedef struct CUDA_CONDITIONAL_NODE_PARAMS { + CUgraphConditionalHandle + handle; /**< Conditional node handle. + Handles must be created in advance of creating the node + using ::cuGraphConditionalHandleCreate. */ + CUgraphConditionalNodeType type; /**< Type of conditional node. */ + unsigned int size; /**< Size of graph output array. Must be 1. */ + CUgraph + *phGraph_out; /**< CUDA-owned array populated with conditional node child + graphs during creation of the node. Valid for the + lifetime of the conditional node. The contents of the + graph(s) are subject to the following constraints: + + - Allowed node types are kernel nodes, empty nodes, + child graphs, memsets, memcopies, and conditionals. This + applies recursively to child graphs and conditional + bodies. + - All kernels, including kernels in nested conditionals + or child graphs at any level, must belong to the same + CUDA context. + + These graphs may be populated using graph node creation + APIs or ::cuStreamBeginCaptureToGraph. */ + CUcontext ctx; /**< Context on which to run the node. Must match context used + to create the handle and all body nodes. 
*/ +} CUDA_CONDITIONAL_NODE_PARAMS; + +/** + * Graph node types + */ +typedef enum CUgraphNodeType_enum { + CU_GRAPH_NODE_TYPE_KERNEL = 0, /**< GPU kernel node */ + CU_GRAPH_NODE_TYPE_MEMCPY = 1, /**< Memcpy node */ + CU_GRAPH_NODE_TYPE_MEMSET = 2, /**< Memset node */ + CU_GRAPH_NODE_TYPE_HOST = 3, /**< Host (executable) node */ + CU_GRAPH_NODE_TYPE_GRAPH = 4, /**< Node which executes an embedded graph */ + CU_GRAPH_NODE_TYPE_EMPTY = 5, /**< Empty (no-op) node */ + CU_GRAPH_NODE_TYPE_WAIT_EVENT = 6, /**< External event wait node */ + CU_GRAPH_NODE_TYPE_EVENT_RECORD = 7, /**< External event record node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL = + 8, /**< External semaphore signal node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT = 9, /**< External semaphore wait node */ + CU_GRAPH_NODE_TYPE_MEM_ALLOC = 10, /**< Memory Allocation Node */ + CU_GRAPH_NODE_TYPE_MEM_FREE = 11, /**< Memory Free Node */ + CU_GRAPH_NODE_TYPE_BATCH_MEM_OP = 12 /**< Batch MemOp Node */ + , + CU_GRAPH_NODE_TYPE_CONDITIONAL = + 13 /**< Conditional Node + + May be used to implement a conditional execution path or loop + inside of a graph. The graph(s) contained within the body of the + conditional node can be selectively executed or iterated upon based + on the value of a conditional variable. + + Handles must be created in advance of creating the node + using ::cuGraphConditionalHandleCreate. + + The following restrictions apply to graphs which contain + conditional nodes: The graph cannot be used in a child node. Only + one instantiation of the graph may exist at any point in time. The + graph cannot be cloned. + + To set the control value, supply a default value when creating the + handle and/or call ::cudaGraphSetConditional from device code.*/ +} CUgraphNodeType; + +/** + * Type annotations that can be applied to graph edges as part of + * ::CUgraphEdgeData. + */ +typedef enum CUgraphDependencyType_enum { + CU_GRAPH_DEPENDENCY_TYPE_DEFAULT = 0, /**< This is an ordinary dependency. */ + CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC = + 1 /**< This dependency type allows the downstream node to + use \c cudaGridDependencySynchronize(). It may only be used + between kernel nodes, and must be used with either the + ::CU_GRAPH_KERNEL_NODE_PORT_PROGRAMMATIC or + ::CU_GRAPH_KERNEL_NODE_PORT_LAUNCH_ORDER outgoing port. */ +} CUgraphDependencyType; + +/** + * This port activates when the kernel has finished executing. + */ +#define CU_GRAPH_KERNEL_NODE_PORT_DEFAULT 0 +/** + * This port activates when all blocks of the kernel have performed + * cudaTriggerProgrammaticLaunchCompletion() or have terminated. It must be used + * with edge type ::CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC. See also + * ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT. + */ +#define CU_GRAPH_KERNEL_NODE_PORT_PROGRAMMATIC 1 +/** + * This port activates when all blocks of the kernel have begun execution. See + * also + * ::CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT. + */ +#define CU_GRAPH_KERNEL_NODE_PORT_LAUNCH_ORDER 2 + +/** + * Optional annotation for edges in a CUDA graph. Note, all edges implicitly + * have annotations and default to a zero-initialized value if not specified. A + * zero-initialized struct indicates a standard full serialization of two nodes + * with memory visibility. + */ +typedef struct CUgraphEdgeData_st { + unsigned char + from_port; /**< This indicates when the dependency is triggered from the + upstream node on the edge. The meaning is specific to the + node type. 
A value of 0 in all cases means full completion + of the upstream node, with memory visibility to the + downstream node or portion thereof (indicated by \c + to_port).
Only kernel nodes define non-zero ports. A + kernel node can use the following output port types: + ::CU_GRAPH_KERNEL_NODE_PORT_DEFAULT, + ::CU_GRAPH_KERNEL_NODE_PORT_PROGRAMMATIC, or + ::CU_GRAPH_KERNEL_NODE_PORT_LAUNCH_ORDER. */ + unsigned char + to_port; /**< This indicates what portion of the downstream node is + dependent on the upstream node or portion thereof + (indicated by \c from_port). The meaning is + specific to the node type. A value of 0 in all + cases means the entirety of the + downstream node is dependent on the + upstream work.
Currently no + node types define non-zero ports. Accordingly, + this field must be set to zero. */ + unsigned char type; /**< This should be populated with a value from + ::CUgraphDependencyType. (It is typed as char due to + compiler-specific layout of bitfields.) See + ::CUgraphDependencyType. */ + unsigned char reserved[5]; /**< These bytes are unused and must be zeroed. + This ensures compatibility if additional fields + are added in the future. */ +} CUgraphEdgeData; + +/** + * Graph instantiation results + */ +typedef enum CUgraphInstantiateResult_enum { + CUDA_GRAPH_INSTANTIATE_SUCCESS = 0, /**< Instantiation succeeded */ + CUDA_GRAPH_INSTANTIATE_ERROR = + 1, /**< Instantiation failed for an unexpected reason which is described + in the return value of the function */ + CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE = + 2, /**< Instantiation failed due to invalid structure, such as cycles */ + CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED = + 3, /**< Instantiation for device launch failed because the graph contained + an unsupported operation */ + CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED = + 4 /**< Instantiation for device launch failed due to the nodes belonging + to different contexts */ +} CUgraphInstantiateResult; + +/** + * Graph instantiation parameters + */ +typedef struct CUDA_GRAPH_INSTANTIATE_PARAMS_st { + cuuint64_t flags; /**< Instantiation flags */ + CUstream hUploadStream; /**< Upload stream */ + CUgraphNode + hErrNode_out; /**< The node which caused instantiation to fail, if any */ + CUgraphInstantiateResult + result_out; /**< Whether instantiation was successful. If it failed, the + reason why */ +} CUDA_GRAPH_INSTANTIATE_PARAMS; + +typedef enum CUsynchronizationPolicy_enum { + CU_SYNC_POLICY_AUTO = 1, + CU_SYNC_POLICY_SPIN = 2, + CU_SYNC_POLICY_YIELD = 3, + CU_SYNC_POLICY_BLOCKING_SYNC = 4 +} CUsynchronizationPolicy; + +/** + * Cluster scheduling policies. These may be passed to ::cuFuncSetAttribute or + * ::cuKernelSetAttribute + */ +typedef enum CUclusterSchedulingPolicy_enum { + CU_CLUSTER_SCHEDULING_POLICY_DEFAULT = 0, /**< the default policy */ + CU_CLUSTER_SCHEDULING_POLICY_SPREAD = + 1, /**< spread the blocks within a cluster to the SMs */ + CU_CLUSTER_SCHEDULING_POLICY_LOAD_BALANCING = + 2 /**< allow the hardware to load-balance the blocks in a cluster to the + SMs */ +} CUclusterSchedulingPolicy; + +/** + * Memory Synchronization Domain + * + * A kernel can be launched in a specified memory synchronization domain that + * affects all memory operations issued by that kernel. A memory barrier issued + * in one domain will only order memory operations in that domain, thus + * eliminating latency increase from memory barriers ordering unrelated traffic. + * + * By default, kernels are launched in domain 0. Kernel launched with + * ::CU_LAUNCH_MEM_SYNC_DOMAIN_REMOTE will have a different domain ID. User may + * also alter the domain ID with ::CUlaunchMemSyncDomainMap for a specific + * stream / graph node / kernel launch. See + * ::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN, ::cuStreamSetAttribute, + * ::cuLaunchKernelEx, + * ::cuGraphKernelNodeSetAttribute. + * + * Memory operations done in kernels launched in different domains are + * considered system-scope distanced. In other words, a GPU scoped memory + * synchronization is not sufficient for memory order to be observed by kernels + * in another memory synchronization domain even if they are on the same GPU. 
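 *
 * A minimal sketch (illustrative only) of routing all kernels launched into
 * one stream to the remote domain, using the stream attribute alias defined
 * later in this header; error checking is omitted:
 * \code
 *   CUstream stream;
 *   cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
 *
 *   CUstreamAttrValue v = {0};
 *   v.memSyncDomain = CU_LAUNCH_MEM_SYNC_DOMAIN_REMOTE;
 *   cuStreamSetAttribute(stream, CU_STREAM_ATTRIBUTE_MEM_SYNC_DOMAIN, &v);
 * \endcode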
+ */ +typedef enum CUlaunchMemSyncDomain_enum { + CU_LAUNCH_MEM_SYNC_DOMAIN_DEFAULT = + 0, /**< Launch kernels in the default domain */ + CU_LAUNCH_MEM_SYNC_DOMAIN_REMOTE = + 1 /**< Launch kernels in the remote domain */ +} CUlaunchMemSyncDomain; + +/** + * Memory Synchronization Domain map + * + * See ::cudaLaunchMemSyncDomain. + * + * By default, kernels are launched in domain 0. Kernel launched with + * ::CU_LAUNCH_MEM_SYNC_DOMAIN_REMOTE will have a different domain ID. User may + * also alter the domain ID with ::CUlaunchMemSyncDomainMap for a specific + * stream / graph node / kernel launch. See + * ::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP. + * + * Domain ID range is available through + * ::CU_DEVICE_ATTRIBUTE_MEM_SYNC_DOMAIN_COUNT. + */ +typedef struct CUlaunchMemSyncDomainMap_st { + unsigned char + default_; /**< The default domain ID to use for designated kernels */ + unsigned char + remote; /**< The remote domain ID to use for designated kernels */ +} CUlaunchMemSyncDomainMap; + +/** + * Launch attributes enum; used as id field of ::CUlaunchAttribute + */ +typedef enum CUlaunchAttributeID_enum { + CU_LAUNCH_ATTRIBUTE_IGNORE = + 0 /**< Ignored entry, for convenient composition */ + , + CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = + 1 /**< Valid for streams, graph nodes, launches. See + ::CUlaunchAttributeValue::accessPolicyWindow. */ + , + CU_LAUNCH_ATTRIBUTE_COOPERATIVE = + 2 /**< Valid for graph nodes, launches. See + ::CUlaunchAttributeValue::cooperative. */ + , + CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = + 3 /**< Valid for streams. See + ::CUlaunchAttributeValue::syncPolicy. */ + , + CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = + 4 /**< Valid for graph nodes, launches. See + ::CUlaunchAttributeValue::clusterDim. */ + , + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = + 5 /**< Valid for graph nodes, launches. See + ::CUlaunchAttributeValue::clusterSchedulingPolicyPreference. */ + , + CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = + 6 /**< Valid for launches. Setting + ::CUlaunchAttributeValue::programmaticStreamSerializationAllowed + to non-0 signals that the kernel will use programmatic + means to resolve its stream dependency, so that the + CUDA runtime should opportunistically allow the grid's + execution to overlap with the previous kernel in the + stream, if that kernel requests the overlap. The + dependent launches can choose to wait on the + dependency using the programmatic sync + (cudaGridDependencySynchronize() or equivalent PTX + instructions). */ + , + CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = + 7 /**< Valid for launches. Set + ::CUlaunchAttributeValue::programmaticEvent to + record the event. Event recorded through this + launch attribute is guaranteed to only trigger + after all block in the associated kernel trigger + the event. A block can trigger the event through + PTX launchdep.release or CUDA builtin function + cudaTriggerProgrammaticLaunchCompletion(). A + trigger can also be inserted at the beginning of + each block's execution if triggerAtBlockStart is + set to non-0. The dependent launches can choose to + wait on the dependency using the programmatic sync + (cudaGridDependencySynchronize() or equivalent PTX + instructions). Note that dependents (including the + CPU thread calling cuEventSynchronize()) are not + guaranteed to observe the release precisely when + it is released. For example, cuEventSynchronize() + may only observe the event trigger long after the + associated kernel has completed. 
This recording + type is primarily meant for establishing + programmatic dependency between device tasks. Note + also that this type of dependency allows, but does not + guarantee, concurrent execution of tasks. +
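                  A launch-side sketch (illustrative only; \c stream, \c kernel
                  and \c kernelArgs are placeholders, error checks omitted):
                  \code
                    CUevent ev;
                    cuEventCreate(&ev, CU_EVENT_DISABLE_TIMING);

                    CUlaunchAttribute attr;
                    attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT;
                    attr.value.programmaticEvent.event = ev;
                    attr.value.programmaticEvent.flags = 0;
                    attr.value.programmaticEvent.triggerAtBlockStart = 0;

                    CUlaunchConfig cfg = {0};
                    cfg.gridDimX  = 128; cfg.gridDimY  = 1; cfg.gridDimZ  = 1;
                    cfg.blockDimX = 256; cfg.blockDimY = 1; cfg.blockDimZ = 1;
                    cfg.hStream  = stream;
                    cfg.attrs    = &attr;
                    cfg.numAttrs = 1;
                    cuLaunchKernelEx(&cfg, kernel, kernelArgs, NULL);
                  \endcode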
+ The event supplied must not be an interprocess or + interop event. The event must disable timing (i.e. + must be created with the ::CU_EVENT_DISABLE_TIMING + flag set). + */ + , + CU_LAUNCH_ATTRIBUTE_PRIORITY = + 8 /**< Valid for streams, graph nodes, launches. See + ::CUlaunchAttributeValue::priority. */ + , + CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP = + 9 /**< Valid for streams, graph nodes, launches. See + ::CUlaunchAttributeValue::memSyncDomainMap. */ + , + CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN = + 10 /**< Valid for streams, graph nodes, launches. See + ::CUlaunchAttributeValue::memSyncDomain. */ + , + CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT = + 12 /**< Valid for launches. Set + ::CUlaunchAttributeValue::launchCompletionEvent to record the + event. +
+ Nominally, the event is triggered once all blocks of the kernel + have begun execution. Currently this is a best effort. If a kernel + B has a launch completion dependency on a kernel A, B may wait + until A is complete. Alternatively, blocks of B may begin before + all blocks of A have begun, for example if B can claim execution + resources unavailable to A (e.g. they run on different GPUs) or + if B is a higher priority than A. + Exercise caution if such an ordering inversion could lead + to deadlock. +
+ A launch completion event is nominally similar to a programmatic + event with \c triggerAtBlockStart set except that it is not + visible to \c cudaGridDependencySynchronize() and can be used with + compute capability less than 9.0. +
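                 The attribute value is filled the same way as a programmatic
                 event; a minimal sketch (illustrative only, \c ev created with
                 ::CU_EVENT_DISABLE_TIMING):
                 \code
                   CUlaunchAttribute attr;
                   attr.id = CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT;
                   attr.value.launchCompletionEvent.event = ev;
                   attr.value.launchCompletionEvent.flags = 0;
                 \endcode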
+ The event supplied must not be an interprocess or interop + event. The event must disable timing (i.e. must be created + with the ::CU_EVENT_DISABLE_TIMING flag set). */ + , + CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE = + 13 /**< Valid for graph nodes, launches. This attribute is graphs-only, + and passing it to a launch in a non-capturing stream will result + in an error. +
+ ::CUlaunchAttributeValue::deviceUpdatableKernelNode::deviceUpdatable + can only be set to 0 or 1. Setting the field to 1 indicates that the + corresponding kernel node should be device-updatable. On success, + a handle will be returned via + ::CUlaunchAttributeValue::deviceUpdatableKernelNode::devNode which + can be passed to the various device-side update functions to update + the node's kernel parameters from within another kernel. For more + information on the types of device updates that can be made, as well + as the relevant limitations thereof, see + ::cudaGraphKernelNodeUpdatesApply.
Nodes which are + device-updatable have additional restrictions compared to regular + kernel nodes. Firstly, device-updatable nodes cannot be removed from + their graph via ::cuGraphDestroyNode. Additionally, once opted-in to + this functionality, a node cannot opt out, and any attempt to set + the deviceUpdatable attribute to 0 will result in an error. + Device-updatable kernel nodes also cannot have their attributes + copied to/from another kernel node via + ::cuGraphKernelNodeCopyAttributes. Graphs containing one or more + device-updatable nodes also do not allow multiple instantiation, + and neither the graph nor its instantiated version can be passed to + ::cuGraphExecUpdate.
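                 An opt-in sketch (illustrative only; \c capStream is assumed to
                 be a stream that is currently capturing into a graph, \c kernel
                 and \c kernelArgs are placeholders, error checks omitted):
                 \code
                   CUlaunchAttribute attr;
                   attr.id = CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE;
                   attr.value.deviceUpdatableKernelNode.deviceUpdatable = 1;

                   CUlaunchConfig cfg = {0};
                   cfg.gridDimX  = 1;   cfg.gridDimY  = 1; cfg.gridDimZ  = 1;
                   cfg.blockDimX = 128; cfg.blockDimY = 1; cfg.blockDimZ = 1;
                   cfg.hStream  = capStream;
                   cfg.attrs    = &attr;
                   cfg.numAttrs = 1;
                   cuLaunchKernelEx(&cfg, kernel, kernelArgs, NULL);
                   // attr.value.deviceUpdatableKernelNode.devNode now holds the
                   // handle used by the device-side node update functions.
                 \endcode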
If a graph contains device-updatable nodes + and updates those nodes from the device from within the graph, the + graph must be uploaded with ::cuGraphUpload before it is launched. + For such a graph, if host-side executable graph updates are made to + the device-updatable nodes, the graph must be uploaded before it is + launched again. */ +#ifdef __CUDA_API_VERSION_INTERNAL + , + CU_LAUNCH_ATTRIBUTE_MAX +#endif +} CUlaunchAttributeID; + +/** + * Launch attributes union; used as value field of ::CUlaunchAttribute + */ +typedef union CUlaunchAttributeValue_union { + char pad[64]; /* Pad to 64 bytes */ + CUaccessPolicyWindow + accessPolicyWindow; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW. */ + int cooperative; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_COOPERATIVE. Nonzero indicates a + cooperative kernel (see + ::cuLaunchCooperativeKernel). */ + CUsynchronizationPolicy + syncPolicy; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY. + ::CUsynchronizationPolicy for work + queued up in this stream */ + + /** + * Value of launch attribute ::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION that + * represents the desired cluster dimensions for the kernel. Opaque type + * with the following fields: + * - \p x - The X dimension of the cluster, in blocks. Must be a divisor + * of the grid X dimension. + * - \p y - The Y dimension of the cluster, in blocks. Must be a divisor + * of the grid Y dimension. + * - \p z - The Z dimension of the cluster, in blocks. Must be a divisor + * of the grid Z dimension. + */ + struct { + unsigned int x; + unsigned int y; + unsigned int z; + } clusterDim; + CUclusterSchedulingPolicy + clusterSchedulingPolicyPreference; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE. + Cluster scheduling policy + preference for the kernel. */ + int programmaticStreamSerializationAllowed; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION. + */ + struct { + CUevent event; /**< Event to fire when all blocks trigger it */ + int flags; /**< Event record flags, see ::cuEventRecordWithFlags. Does not + accept + ::CU_EVENT_RECORD_EXTERNAL. */ + int triggerAtBlockStart; /**< If this is set to non-0, each block launch + will automatically trigger the event */ + } programmaticEvent; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT. */ + struct { + CUevent event; /**< Event to fire when the last block launches */ + int flags; /**< Event record flags, see ::cuEventRecordWithFlags. Does not + accept ::CU_EVENT_RECORD_EXTERNAL. */ + } launchCompletionEvent; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT. */ + int priority; /**< Value of launch attribute ::CU_LAUNCH_ATTRIBUTE_PRIORITY. + Execution priority of the kernel. */ + CUlaunchMemSyncDomainMap + memSyncDomainMap; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP. See + ::CUlaunchMemSyncDomainMap. */ + CUlaunchMemSyncDomain memSyncDomain; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN. + See::CUlaunchMemSyncDomain */ + + struct { + int deviceUpdatable; /**< Whether or not the resulting kernel node should be + device-updatable. */ + CUgraphDeviceNode devNode; /**< Returns a handle to pass to the various + device-side update functions. */ + } deviceUpdatableKernelNode; /**< Value of launch attribute + ::CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE. 
+ */ +} CUlaunchAttributeValue; + +/** + * Launch attribute + */ +typedef struct CUlaunchAttribute_st { + CUlaunchAttributeID id; /**< Attribute to set */ + char pad[8 - sizeof(CUlaunchAttributeID)]; + CUlaunchAttributeValue value; /**< Value of the attribute */ +} CUlaunchAttribute; + +/** + * CUDA extensible launch configuration + */ +typedef struct CUlaunchConfig_st { + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block + in bytes */ + CUstream hStream; /**< Stream identifier */ + CUlaunchAttribute *attrs; /**< List of attributes; nullable if + ::CUlaunchConfig::numAttrs == 0 */ + unsigned int numAttrs; /**< Number of attributes populated in + ::CUlaunchConfig::attrs */ +} CUlaunchConfig; + +typedef CUlaunchAttributeID CUkernelNodeAttrID; +#define CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW \ + CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW +#define CU_KERNEL_NODE_ATTRIBUTE_COOPERATIVE CU_LAUNCH_ATTRIBUTE_COOPERATIVE +#define CU_KERNEL_NODE_ATTRIBUTE_CLUSTER_DIMENSION \ + CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION +#define CU_KERNEL_NODE_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE \ + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE +#define CU_KERNEL_NODE_ATTRIBUTE_PRIORITY CU_LAUNCH_ATTRIBUTE_PRIORITY +#define CU_KERNEL_NODE_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP \ + CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP +#define CU_KERNEL_NODE_ATTRIBUTE_MEM_SYNC_DOMAIN \ + CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN +#define CU_KERNEL_NODE_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE \ + CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE + +typedef CUlaunchAttributeValue CUkernelNodeAttrValue_v1; +typedef CUkernelNodeAttrValue_v1 CUkernelNodeAttrValue; + +/** + * Possible stream capture statuses returned by ::cuStreamIsCapturing + */ +typedef enum CUstreamCaptureStatus_enum { + CU_STREAM_CAPTURE_STATUS_NONE = 0, /**< Stream is not capturing */ + CU_STREAM_CAPTURE_STATUS_ACTIVE = 1, /**< Stream is actively capturing */ + CU_STREAM_CAPTURE_STATUS_INVALIDATED = + 2 /**< Stream is part of a capture sequence that + has been invalidated, but not terminated */ +} CUstreamCaptureStatus; + +/** + * Possible modes for stream capture thread interactions. For more details see + * ::cuStreamBeginCapture and ::cuThreadExchangeStreamCaptureMode + */ +typedef enum CUstreamCaptureMode_enum { + CU_STREAM_CAPTURE_MODE_GLOBAL = 0, + CU_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + CU_STREAM_CAPTURE_MODE_RELAXED = 2 +} CUstreamCaptureMode; + +typedef CUlaunchAttributeID CUstreamAttrID; +#define CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW \ + CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW +#define CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY \ + CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY +#define CU_STREAM_ATTRIBUTE_PRIORITY CU_LAUNCH_ATTRIBUTE_PRIORITY +#define CU_STREAM_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP \ + CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP +#define CU_STREAM_ATTRIBUTE_MEM_SYNC_DOMAIN CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN + +typedef CUlaunchAttributeValue CUstreamAttrValue_v1; +typedef CUstreamAttrValue_v1 CUstreamAttrValue; + +/** + * Flags to specify search options. 
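 * A lookup sketch (illustrative only), assuming the five-parameter CUDA 12
 * form of ::cuGetProcAddress that also reports a
 * ::CUdriverProcAddressQueryResult:
 * \code
 *   void *pfn = NULL;
 *   CUdriverProcAddressQueryResult status;
 *   cuGetProcAddress("cuMemAllocAsync", &pfn, 12000,
 *                    CU_GET_PROC_ADDRESS_DEFAULT, &status);
 *   if (status == CU_GET_PROC_ADDRESS_SUCCESS && pfn != NULL) {
 *     // Cast pfn to the matching driver entry-point type before calling it.
 *   }
 * \endcode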
For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddress_flags_enum { + CU_GET_PROC_ADDRESS_DEFAULT = + 0, /**< Default search mode for driver symbols. */ + CU_GET_PROC_ADDRESS_LEGACY_STREAM = + 1 << 0, /**< Search for legacy versions of driver symbols. */ + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM = + 1 << 1 /**< Search for per-thread versions of driver symbols. */ +} CUdriverProcAddress_flags; + +/** + * Flags to indicate search status. For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddressQueryResult_enum { + CU_GET_PROC_ADDRESS_SUCCESS = 0, /**< Symbol was successfully found */ + CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = + 1, /**< Symbol was not found in search */ + CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = + 2 /**< Symbol was found but version supplied was not sufficient */ +} CUdriverProcAddressQueryResult; + +/** + * Execution Affinity Types + */ +typedef enum CUexecAffinityType_enum { + CU_EXEC_AFFINITY_TYPE_SM_COUNT = 0, /**< Create a context with limited SMs. */ + CU_EXEC_AFFINITY_TYPE_MAX +} CUexecAffinityType; + +/** + * Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT + */ +typedef struct CUexecAffinitySmCount_st { + unsigned int val; /**< The number of SMs the context is limited to use. */ +} CUexecAffinitySmCount_v1; +typedef CUexecAffinitySmCount_v1 CUexecAffinitySmCount; + +/** + * Execution Affinity Parameters + */ +typedef struct CUexecAffinityParam_st { + CUexecAffinityType type; + union { + CUexecAffinitySmCount + smCount; /** Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT */ + } param; +} CUexecAffinityParam_v1; +/** + * Execution Affinity Parameters + */ +typedef CUexecAffinityParam_v1 CUexecAffinityParam; + +/** + * Library options to be specified with ::cuLibraryLoadData() or + * ::cuLibraryLoadFromFile() + */ +typedef enum CUlibraryOption_enum { + CU_LIBRARY_HOST_UNIVERSAL_FUNCTION_AND_DATA_TABLE = 0, + + /** + * Specifies that the argument \p code passed to ::cuLibraryLoadData() will be + * preserved. Specifying this option will let the driver know that \p code can + * be accessed at any point until ::cuLibraryUnload(). The default behavior is + * for the driver to allocate and maintain its own copy of \p code. Note that + * this is only a memory usage optimization hint and the driver can choose to + * ignore it if required. Specifying this option with + * ::cuLibraryLoadFromFile() is invalid and will return + * ::CUDA_ERROR_INVALID_VALUE. + */ + CU_LIBRARY_BINARY_IS_PRESERVED = 1, + + CU_LIBRARY_NUM_OPTIONS +} CUlibraryOption; + +typedef struct CUlibraryHostUniversalFunctionAndDataTable_st { + void *functionTable; + size_t functionWindowSize; + void *dataTable; + size_t dataWindowSize; +} CUlibraryHostUniversalFunctionAndDataTable; + +/** + * Error codes + */ +typedef enum cudaError_enum { + /** + * The API call returned with no errors. In the case of query calls, this + * also means that the operation being queried is complete (see + * ::cuEventQuery() and ::cuStreamQuery()). + */ + CUDA_SUCCESS = 0, + + /** + * This indicates that one or more of the parameters passed to the API call + * is not within an acceptable range of values. + */ + CUDA_ERROR_INVALID_VALUE = 1, + + /** + * The API call failed because it was unable to allocate enough memory or + * other resources to perform the requested operation. + */ + CUDA_ERROR_OUT_OF_MEMORY = 2, + + /** + * This indicates that the CUDA driver has not been initialized with + * ::cuInit() or that initialization has failed. 
+ */ + CUDA_ERROR_NOT_INITIALIZED = 3, + + /** + * This indicates that the CUDA driver is in the process of shutting down. + */ + CUDA_ERROR_DEINITIALIZED = 4, + + /** + * This indicates profiler is not initialized for this run. This can + * happen when the application is running with external profiling tools + * like visual profiler. + */ + CUDA_ERROR_PROFILER_DISABLED = 5, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to attempt to enable/disable the profiling via ::cuProfilerStart or + * ::cuProfilerStop without initialization. + */ + CUDA_ERROR_PROFILER_NOT_INITIALIZED = 6, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to call cuProfilerStart() when profiling is already enabled. + */ + CUDA_ERROR_PROFILER_ALREADY_STARTED = 7, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to call cuProfilerStop() when profiling is already disabled. + */ + CUDA_ERROR_PROFILER_ALREADY_STOPPED = 8, + + /** + * This indicates that the CUDA driver that the application has loaded is a + * stub library. Applications that run with the stub rather than a real + * driver loaded will result in CUDA API returning this error. + */ + CUDA_ERROR_STUB_LIBRARY = 34, + + /** + * This indicates that requested CUDA device is unavailable at the current + * time. Devices are often unavailable due to use of + * ::CU_COMPUTEMODE_EXCLUSIVE_PROCESS or ::CU_COMPUTEMODE_PROHIBITED. + */ + CUDA_ERROR_DEVICE_UNAVAILABLE = 46, + + /** + * This indicates that no CUDA-capable devices were detected by the installed + * CUDA driver. + */ + CUDA_ERROR_NO_DEVICE = 100, + + /** + * This indicates that the device ordinal supplied by the user does not + * correspond to a valid CUDA device or that the action requested is + * invalid for the specified device. + */ + CUDA_ERROR_INVALID_DEVICE = 101, + + /** + * This error indicates that the Grid license is not applied. + */ + CUDA_ERROR_DEVICE_NOT_LICENSED = 102, + + /** + * This indicates that the device kernel image is invalid. This can also + * indicate an invalid CUDA module. + */ + CUDA_ERROR_INVALID_IMAGE = 200, + + /** + * This most frequently indicates that there is no context bound to the + * current thread. This can also be returned if the context passed to an + * API call is not a valid handle (such as a context that has had + * ::cuCtxDestroy() invoked on it). This can also be returned if a user + * mixes different API versions (i.e. 3010 context with 3020 API calls). + * See ::cuCtxGetApiVersion() for more details. + * This can also be returned if the green context passed to an API call + * was not converted to a ::CUcontext using ::cuCtxFromGreenCtx API. + */ + CUDA_ERROR_INVALID_CONTEXT = 201, + + /** + * This indicated that the context being supplied as a parameter to the + * API call was already the active context. + * \deprecated + * This error return is deprecated as of CUDA 3.2. It is no longer an + * error to attempt to push the active context via ::cuCtxPushCurrent(). + */ + CUDA_ERROR_CONTEXT_ALREADY_CURRENT = 202, + + /** + * This indicates that a map or register operation has failed. + */ + CUDA_ERROR_MAP_FAILED = 205, + + /** + * This indicates that an unmap or unregister operation has failed. + */ + CUDA_ERROR_UNMAP_FAILED = 206, + + /** + * This indicates that the specified array is currently mapped and thus + * cannot be destroyed. 
+ */ + CUDA_ERROR_ARRAY_IS_MAPPED = 207, + + /** + * This indicates that the resource is already mapped. + */ + CUDA_ERROR_ALREADY_MAPPED = 208, + + /** + * This indicates that there is no kernel image available that is suitable + * for the device. This can occur when a user specifies code generation + * options for a particular CUDA source file that do not include the + * corresponding device configuration. + */ + CUDA_ERROR_NO_BINARY_FOR_GPU = 209, + + /** + * This indicates that a resource has already been acquired. + */ + CUDA_ERROR_ALREADY_ACQUIRED = 210, + + /** + * This indicates that a resource is not mapped. + */ + CUDA_ERROR_NOT_MAPPED = 211, + + /** + * This indicates that a mapped resource is not available for access as an + * array. + */ + CUDA_ERROR_NOT_MAPPED_AS_ARRAY = 212, + + /** + * This indicates that a mapped resource is not available for access as a + * pointer. + */ + CUDA_ERROR_NOT_MAPPED_AS_POINTER = 213, + + /** + * This indicates that an uncorrectable ECC error was detected during + * execution. + */ + CUDA_ERROR_ECC_UNCORRECTABLE = 214, + + /** + * This indicates that the ::CUlimit passed to the API call is not + * supported by the active device. + */ + CUDA_ERROR_UNSUPPORTED_LIMIT = 215, + + /** + * This indicates that the ::CUcontext passed to the API call can + * only be bound to a single CPU thread at a time but is already + * bound to a CPU thread. + */ + CUDA_ERROR_CONTEXT_ALREADY_IN_USE = 216, + + /** + * This indicates that peer access is not supported across the given + * devices. + */ + CUDA_ERROR_PEER_ACCESS_UNSUPPORTED = 217, + + /** + * This indicates that a PTX JIT compilation failed. + */ + CUDA_ERROR_INVALID_PTX = 218, + + /** + * This indicates an error with OpenGL or DirectX context. + */ + CUDA_ERROR_INVALID_GRAPHICS_CONTEXT = 219, + + /** + * This indicates that an uncorrectable NVLink error was detected during the + * execution. + */ + CUDA_ERROR_NVLINK_UNCORRECTABLE = 220, + + /** + * This indicates that the PTX JIT compiler library was not found. + */ + CUDA_ERROR_JIT_COMPILER_NOT_FOUND = 221, + + /** + * This indicates that the provided PTX was compiled with an unsupported + * toolchain. + */ + + CUDA_ERROR_UNSUPPORTED_PTX_VERSION = 222, + + /** + * This indicates that the PTX JIT compilation was disabled. + */ + CUDA_ERROR_JIT_COMPILATION_DISABLED = 223, + + /** + * This indicates that the ::CUexecAffinityType passed to the API call is not + * supported by the active device. + */ + CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, + + /** + * This indicates that the code to be compiled by the PTX JIT contains + * unsupported call to cudaDeviceSynchronize. + */ + CUDA_ERROR_UNSUPPORTED_DEVSIDE_SYNC = 225, + + /** + * This indicates that the device kernel source is invalid. This includes + * compilation/linker errors encountered in device code or user error. + */ + CUDA_ERROR_INVALID_SOURCE = 300, + + /** + * This indicates that the file specified was not found. + */ + CUDA_ERROR_FILE_NOT_FOUND = 301, + + /** + * This indicates that a link to a shared object failed to resolve. + */ + CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND = 302, + + /** + * This indicates that initialization of a shared object failed. + */ + CUDA_ERROR_SHARED_OBJECT_INIT_FAILED = 303, + + /** + * This indicates that an OS call failed. + */ + CUDA_ERROR_OPERATING_SYSTEM = 304, + + /** + * This indicates that a resource handle passed to the API call was not + * valid. Resource handles are opaque types like ::CUstream and ::CUevent. 
+ */ + CUDA_ERROR_INVALID_HANDLE = 400, + + /** + * This indicates that a resource required by the API call is not in a + * valid state to perform the requested operation. + */ + CUDA_ERROR_ILLEGAL_STATE = 401, + + /** + * This indicates an attempt was made to introspect an object in a way that + * would discard semantically important information. This is either due to + * the object using functionality newer than the API version used to + * introspect it or omission of optional return arguments. + */ + CUDA_ERROR_LOSSY_QUERY = 402, + + /** + * This indicates that a named symbol was not found. Examples of symbols + * are global/constant variable names, driver function names, texture names, + * and surface names. + */ + CUDA_ERROR_NOT_FOUND = 500, + + /** + * This indicates that asynchronous operations issued previously have not + * completed yet. This result is not actually an error, but must be indicated + * differently than ::CUDA_SUCCESS (which indicates completion). Calls that + * may return this value include ::cuEventQuery() and ::cuStreamQuery(). + */ + CUDA_ERROR_NOT_READY = 600, + + /** + * While executing a kernel, the device encountered a + * load or store instruction on an invalid memory address. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_ILLEGAL_ADDRESS = 700, + + /** + * This indicates that a launch did not occur because it did not have + * appropriate resources. This error usually indicates that the user has + * attempted to pass too many arguments to the device kernel, or the + * kernel launch specifies too many threads for the kernel's register + * count. Passing arguments of the wrong size (i.e. a 64-bit pointer + * when a 32-bit int is expected) is equivalent to passing too many + * arguments and can also result in this error. + */ + CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES = 701, + + /** + * This indicates that the device kernel took too long to execute. This can + * only occur if timeouts are enabled - see the device attribute + * ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT for more information. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_LAUNCH_TIMEOUT = 702, + + /** + * This error indicates a kernel launch that uses an incompatible texturing + * mode. + */ + CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING = 703, + + /** + * This error indicates that a call to ::cuCtxEnablePeerAccess() is + * trying to re-enable peer access to a context which has already + * had peer access to it enabled. + */ + CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED = 704, + + /** + * This error indicates that ::cuCtxDisablePeerAccess() is + * trying to disable peer access which has not been enabled yet + * via ::cuCtxEnablePeerAccess(). + */ + CUDA_ERROR_PEER_ACCESS_NOT_ENABLED = 705, + + /** + * This error indicates that the primary context for the specified device + * has already been initialized. + */ + CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE = 708, + + /** + * This error indicates that the context current to the calling thread + * has been destroyed using ::cuCtxDestroy, or is a primary context which + * has not yet been initialized. + */ + CUDA_ERROR_CONTEXT_IS_DESTROYED = 709, + + /** + * A device-side assert triggered during kernel execution. 
The context + * cannot be used anymore, and must be destroyed. All existing device + * memory allocations from this context are invalid and must be + * reconstructed if the program is to continue using CUDA. + */ + CUDA_ERROR_ASSERT = 710, + + /** + * This error indicates that the hardware resources required to enable + * peer access have been exhausted for one or more of the devices + * passed to ::cuCtxEnablePeerAccess(). + */ + CUDA_ERROR_TOO_MANY_PEERS = 711, + + /** + * This error indicates that the memory range passed to ::cuMemHostRegister() + * has already been registered. + */ + CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED = 712, + + /** + * This error indicates that the pointer passed to ::cuMemHostUnregister() + * does not correspond to any currently registered memory region. + */ + CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED = 713, + + /** + * While executing a kernel, the device encountered a stack error. + * This can be due to stack corruption or exceeding the stack size limit. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_HARDWARE_STACK_ERROR = 714, + + /** + * While executing a kernel, the device encountered an illegal instruction. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_ILLEGAL_INSTRUCTION = 715, + + /** + * While executing a kernel, the device encountered a load or store + * instruction on a memory address which is not aligned. This leaves the + * process in an inconsistent state and any further CUDA work will return the + * same error. To continue using CUDA, the process must be terminated and + * relaunched. + */ + CUDA_ERROR_MISALIGNED_ADDRESS = 716, + + /** + * While executing a kernel, the device encountered an instruction + * which can only operate on memory locations in certain address spaces + * (global, shared, or local), but was supplied a memory address not + * belonging to an allowed address space. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_INVALID_ADDRESS_SPACE = 717, + + /** + * While executing a kernel, the device program counter wrapped its address + * space. This leaves the process in an inconsistent state and any further + * CUDA work will return the same error. To continue using CUDA, the process + * must be terminated and relaunched. + */ + CUDA_ERROR_INVALID_PC = 718, + + /** + * An exception occurred on the device while executing a kernel. Common + * causes include dereferencing an invalid device pointer and accessing + * out of bounds shared memory. Less common cases can be system specific - + * more information about these cases can be found in the system specific user + * guide. This leaves the process in an inconsistent state and any further + * CUDA work will return the same error. To continue using CUDA, the process + * must be terminated and relaunched. 
+ */ + CUDA_ERROR_LAUNCH_FAILED = 719, + + /** + * This error indicates that the number of blocks launched per grid for a + * kernel that was launched via either ::cuLaunchCooperativeKernel or + * ::cuLaunchCooperativeKernelMultiDevice exceeds the maximum number of blocks + * as allowed by ::cuOccupancyMaxActiveBlocksPerMultiprocessor or + * ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags times the number of + * multiprocessors as specified by the device attribute + * ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. + */ + CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE = 720, + + /** + * This error indicates that the attempted operation is not permitted. + */ + CUDA_ERROR_NOT_PERMITTED = 800, + + /** + * This error indicates that the attempted operation is not supported + * on the current system or device. + */ + CUDA_ERROR_NOT_SUPPORTED = 801, + + /** + * This error indicates that the system is not yet ready to start any CUDA + * work. To continue using CUDA, verify the system configuration is in a + * valid state and all required driver daemons are actively running. + * More information about this error can be found in the system specific + * user guide. + */ + CUDA_ERROR_SYSTEM_NOT_READY = 802, + + /** + * This error indicates that there is a mismatch between the versions of + * the display driver and the CUDA driver. Refer to the compatibility + * documentation for supported versions. + */ + CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, + + /** + * This error indicates that the system was upgraded to run with forward + * compatibility but the visible hardware detected by CUDA does not support + * this configuration. Refer to the compatibility documentation for the + * supported hardware matrix or ensure that only supported hardware is visible + * during initialization via the CUDA_VISIBLE_DEVICES environment variable. + */ + CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE = 804, + + /** + * This error indicates that the MPS client failed to connect to the MPS + * control daemon or the MPS server. + */ + CUDA_ERROR_MPS_CONNECTION_FAILED = 805, + + /** + * This error indicates that the remote procedural call between the MPS server + * and the MPS client failed. + */ + CUDA_ERROR_MPS_RPC_FAILURE = 806, + + /** + * This error indicates that the MPS server is not ready to accept new MPS + * client requests. This error can be returned when the MPS server is in the + * process of recovering from a fatal failure. + */ + CUDA_ERROR_MPS_SERVER_NOT_READY = 807, + + /** + * This error indicates that the hardware resources required to create MPS + * client have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CLIENTS_REACHED = 808, + + /** + * This error indicates the the hardware resources required to support device + * connections have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CONNECTIONS_REACHED = 809, + + /** + * This error indicates that the MPS client has been terminated by the server. + * To continue using CUDA, the process must be terminated and relaunched. + */ + CUDA_ERROR_MPS_CLIENT_TERMINATED = 810, + + /** + * This error indicates that the module is using CUDA Dynamic Parallelism, but + * the current configuration, like MPS, does not support it. + */ + CUDA_ERROR_CDP_NOT_SUPPORTED = 811, + + /** + * This error indicates that a module contains an unsupported interaction + * between different versions of CUDA Dynamic Parallelism. + */ + CUDA_ERROR_CDP_VERSION_MISMATCH = 812, + + /** + * This error indicates that the operation is not permitted when + * the stream is capturing. 
+ */ + CUDA_ERROR_STREAM_CAPTURE_UNSUPPORTED = 900, + + /** + * This error indicates that the current capture sequence on the stream + * has been invalidated due to a previous error. + */ + CUDA_ERROR_STREAM_CAPTURE_INVALIDATED = 901, + + /** + * This error indicates that the operation would have resulted in a merge + * of two independent capture sequences. + */ + CUDA_ERROR_STREAM_CAPTURE_MERGE = 902, + + /** + * This error indicates that the capture was not initiated in this stream. + */ + CUDA_ERROR_STREAM_CAPTURE_UNMATCHED = 903, + + /** + * This error indicates that the capture sequence contains a fork that was + * not joined to the primary stream. + */ + CUDA_ERROR_STREAM_CAPTURE_UNJOINED = 904, + + /** + * This error indicates that a dependency would have been created which + * crosses the capture sequence boundary. Only implicit in-stream ordering + * dependencies are allowed to cross the boundary. + */ + CUDA_ERROR_STREAM_CAPTURE_ISOLATION = 905, + + /** + * This error indicates a disallowed implicit dependency on a current capture + * sequence from cudaStreamLegacy. + */ + CUDA_ERROR_STREAM_CAPTURE_IMPLICIT = 906, + + /** + * This error indicates that the operation is not permitted on an event which + * was last recorded in a capturing stream. + */ + CUDA_ERROR_CAPTURED_EVENT = 907, + + /** + * A stream capture sequence not initiated with the + * ::CU_STREAM_CAPTURE_MODE_RELAXED argument to ::cuStreamBeginCapture was + * passed to ::cuStreamEndCapture in a different thread. + */ + CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD = 908, + + /** + * This error indicates that the timeout specified for the wait operation has + * lapsed. + */ + CUDA_ERROR_TIMEOUT = 909, + + /** + * This error indicates that the graph update was not performed because it + * included changes which violated constraints specific to instantiated graph + * update. + */ + CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE = 910, + + /** + * This indicates that an async error has occurred in a device outside of + * CUDA. If CUDA was waiting for an external device's signal before consuming + * shared data, the external device signaled an error indicating that the data + * is not valid for consumption. This leaves the process in an inconsistent + * state and any further CUDA work will return the same error. To continue + * using CUDA, the process must be terminated and relaunched. + */ + CUDA_ERROR_EXTERNAL_DEVICE = 911, + + /** + * Indicates a kernel launch error due to cluster misconfiguration. + */ + CUDA_ERROR_INVALID_CLUSTER_SIZE = 912, + + /** + * Indicates a function handle is not loaded when calling an API that requires + * a loaded function. + */ + CUDA_ERROR_FUNCTION_NOT_LOADED = 913, + + /** + * This error indicates one or more resources passed in are not valid resource + * types for the operation. + */ + CUDA_ERROR_INVALID_RESOURCE_TYPE = 914, + + /** + * This error indicates one or more resources are insufficient or + * non-applicable for the operation. + */ + CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION = 915, + + /** + * This indicates that an unknown internal error has occurred. 
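 *
 * As with every other ::CUresult above, callers typically funnel driver
 * calls through a small checking helper. A minimal sketch (illustrative
 * only; assumes <stdio.h> and <stdlib.h>, and the abort-on-error policy is
 * the application's choice):
 * \code
 *   static void drvCheck(CUresult res, const char *what) {
 *     if (res != CUDA_SUCCESS) {
 *       const char *name = NULL;
 *       cuGetErrorName(res, &name);
 *       fprintf(stderr, "%s failed: %s\n", what, name ? name : "unknown");
 *       abort();
 *     }
 *   }
 *
 *   drvCheck(cuInit(0), "cuInit");
 * \endcode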
+ */ + CUDA_ERROR_UNKNOWN = 999 +} CUresult; + +/** + * P2P Attributes + */ +typedef enum CUdevice_P2PAttribute_enum { + CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK = + 0x01, /**< A relative value indicating the performance of the link between + two devices */ + CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED = 0x02, /**< P2P Access is enable */ + CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED = + 0x03, /**< Atomic operation over the link supported */ + CU_DEVICE_P2P_ATTRIBUTE_ACCESS_ACCESS_SUPPORTED = + 0x04, /**< \deprecated use + CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED instead */ + CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED = + 0x04 /**< Accessing CUDA arrays over the link supported */ +} CUdevice_P2PAttribute; + +/** + * CUDA stream callback + * \param hStream The stream the callback was added to, as passed to + * ::cuStreamAddCallback. May be NULL. \param status ::CUDA_SUCCESS or any + * persistent error on the stream. \param userData User parameter provided at + * registration. + */ +typedef void(CUDA_CB *CUstreamCallback)(CUstream hStream, CUresult status, + void *userData); + +/** + * Block size to per-block dynamic shared memory mapping for a certain + * kernel \param blockSize Block size of the kernel. + * + * \return The dynamic shared memory needed by a block. + */ +typedef size_t(CUDA_CB *CUoccupancyB2DSize)(int blockSize); + +/** + * If set, host memory is portable between CUDA contexts. + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_PORTABLE 0x01 + +/** + * If set, host memory is mapped into CUDA address space and + * ::cuMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_DEVICEMAP 0x02 + +/** + * If set, host memory is allocated as write-combined - fast to write, + * faster to DMA, slow to read except via SSE4 streaming load instruction + * (MOVNTDQA). + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_WRITECOMBINED 0x04 + +/** + * If set, host memory is portable between CUDA contexts. + * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_PORTABLE 0x01 + +/** + * If set, host memory is mapped into CUDA address space and + * ::cuMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_DEVICEMAP 0x02 + +/** + * If set, the passed memory pointer is treated as pointing to some + * memory-mapped I/O space, e.g. belonging to a third-party PCIe device. + * On Windows the flag is a no-op. + * On Linux that memory is marked as non cache-coherent for the GPU and + * is expected to be physically contiguous. It may return + * ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, + * ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. + * On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED + * is returned. + * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_IOMEMORY 0x04 + +/** + * If set, the passed memory pointer is treated as pointing to memory that is + * considered read-only by the device. On platforms without + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag + * is required in order to register memory mapped to the CPU as read-only. + * Support for the use of this flag can be queried from the device attribute + * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. 
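 * A registration sketch (illustrative only; \c buf, \c bytes and \c dev are
 * placeholders supplied by the caller, error checks omitted):
 * \code
 *   int readOnlyOk = 0;
 *   cuDeviceGetAttribute(&readOnlyOk,
 *                        CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED,
 *                        dev);
 *   unsigned int flags = CU_MEMHOSTREGISTER_DEVICEMAP;
 *   if (readOnlyOk)
 *     flags |= CU_MEMHOSTREGISTER_READ_ONLY;
 *   cuMemHostRegister(buf, bytes, flags);
 * \endcode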
Using this flag + * with a current context associated with a device that does not have this + * attribute set will cause ::cuMemHostRegister to error with + * ::CUDA_ERROR_NOT_SUPPORTED. + */ +#define CU_MEMHOSTREGISTER_READ_ONLY 0x08 + +/** + * 2D memory copy parameters + */ +typedef struct CUDA_MEMCPY2D_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + + CUmemorytype + dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + + size_t WidthInBytes; /**< Width of 2D memory copy in bytes */ + size_t Height; /**< Height of 2D memory copy */ +} CUDA_MEMCPY2D_v2; +typedef CUDA_MEMCPY2D_v2 CUDA_MEMCPY2D; + +/** + * 3D memory copy parameters + */ +typedef struct CUDA_MEMCPY3D_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + size_t srcZ; /**< Source Z */ + size_t srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + void *reserved0; /**< Must be NULL */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + size_t srcHeight; /**< Source height (ignored when src is array; may be 0 if + Depth==1) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + size_t dstZ; /**< Destination Z */ + size_t dstLOD; /**< Destination LOD */ + CUmemorytype + dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + void *reserved1; /**< Must be NULL */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + size_t dstHeight; /**< Destination height (ignored when dst is array; may be 0 + if Depth==1) */ + + size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ + size_t Height; /**< Height of 3D memory copy */ + size_t Depth; /**< Depth of 3D memory copy */ +} CUDA_MEMCPY3D_v2; +typedef CUDA_MEMCPY3D_v2 CUDA_MEMCPY3D; + +/** + * 3D memory cross-context copy parameters + */ +typedef struct CUDA_MEMCPY3D_PEER_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + size_t srcZ; /**< Source Z */ + size_t srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + CUcontext srcContext; /**< Source context (ignored with srcMemoryType is + ::CU_MEMORYTYPE_ARRAY) */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + size_t srcHeight; /**< Source height (ignored when src is array; may be 0 if + 
Depth==1) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + size_t dstZ; /**< Destination Z */ + size_t dstLOD; /**< Destination LOD */ + CUmemorytype + dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + CUcontext dstContext; /**< Destination context (ignored with dstMemoryType is + ::CU_MEMORYTYPE_ARRAY) */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + size_t dstHeight; /**< Destination height (ignored when dst is array; may be 0 + if Depth==1) */ + + size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ + size_t Height; /**< Height of 3D memory copy */ + size_t Depth; /**< Depth of 3D memory copy */ +} CUDA_MEMCPY3D_PEER_v1; +typedef CUDA_MEMCPY3D_PEER_v1 CUDA_MEMCPY3D_PEER; + +/** + * Memcpy node parameters + */ +typedef struct CUDA_MEMCPY_NODE_PARAMS_st { + int flags; /**< Must be zero */ + int reserved; /**< Must be zero */ + CUcontext copyCtx; /**< Context on which to run the node */ + CUDA_MEMCPY3D copyParams; /**< Parameters for the memory copy */ +} CUDA_MEMCPY_NODE_PARAMS; + +/** + * Array descriptor + */ +typedef struct CUDA_ARRAY_DESCRIPTOR_st { + size_t Width; /**< Width of array */ + size_t Height; /**< Height of array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ +} CUDA_ARRAY_DESCRIPTOR_v2; +typedef CUDA_ARRAY_DESCRIPTOR_v2 CUDA_ARRAY_DESCRIPTOR; + +/** + * 3D array descriptor + */ +typedef struct CUDA_ARRAY3D_DESCRIPTOR_st { + size_t Width; /**< Width of 3D array */ + size_t Height; /**< Height of 3D array */ + size_t Depth; /**< Depth of 3D array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + unsigned int Flags; /**< Flags */ +} CUDA_ARRAY3D_DESCRIPTOR_v2; +typedef CUDA_ARRAY3D_DESCRIPTOR_v2 CUDA_ARRAY3D_DESCRIPTOR; + +/** + * Indicates that the layered sparse CUDA array or CUDA mipmapped array has a + * single mip tail region for all layers + */ +#define CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL 0x1 + +/** + * CUDA array sparse properties + */ +typedef struct CUDA_ARRAY_SPARSE_PROPERTIES_st { + struct { + unsigned int width; /**< Width of sparse tile in elements */ + unsigned int height; /**< Height of sparse tile in elements */ + unsigned int depth; /**< Depth of sparse tile in elements */ + } tileExtent; + + /** + * First mip level at which the mip tail begins. + */ + unsigned int miptailFirstLevel; + /** + * Total size of the mip tail. 
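 *
 * For the ::CUDA_MEMCPY2D descriptor defined earlier in this header, a
 * typical strided host-to-device copy looks like the following sketch
 * (illustrative only; \c hostBuf, \c devPtr, \c devPitch, \c width and
 * \c height are placeholders):
 * \code
 *   CUDA_MEMCPY2D cpy = {0};
 *   cpy.srcMemoryType = CU_MEMORYTYPE_HOST;
 *   cpy.srcHost       = hostBuf;
 *   cpy.srcPitch      = width * sizeof(float);
 *   cpy.dstMemoryType = CU_MEMORYTYPE_DEVICE;
 *   cpy.dstDevice     = devPtr;
 *   cpy.dstPitch      = devPitch;
 *   cpy.WidthInBytes  = width * sizeof(float);
 *   cpy.Height        = height;
 *   cuMemcpy2D(&cpy);
 * \endcode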
+ */ + unsigned long long miptailSize; + /** + * Flags will either be zero or ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL + */ + unsigned int flags; + unsigned int reserved[4]; +} CUDA_ARRAY_SPARSE_PROPERTIES_v1; +typedef CUDA_ARRAY_SPARSE_PROPERTIES_v1 CUDA_ARRAY_SPARSE_PROPERTIES; + +/** + * CUDA array memory requirements + */ +typedef struct CUDA_ARRAY_MEMORY_REQUIREMENTS_st { + size_t size; /**< Total required memory size */ + size_t alignment; /**< alignment requirement */ + unsigned int reserved[4]; +} CUDA_ARRAY_MEMORY_REQUIREMENTS_v1; +typedef CUDA_ARRAY_MEMORY_REQUIREMENTS_v1 CUDA_ARRAY_MEMORY_REQUIREMENTS; + +/** + * CUDA Resource descriptor + */ +typedef struct CUDA_RESOURCE_DESC_st { + CUresourcetype resType; /**< Resource type */ + + union { + struct { + CUarray hArray; /**< CUDA array */ + } array; + struct { + CUmipmappedArray hMipmappedArray; /**< CUDA mipmapped array */ + } mipmap; + struct { + CUdeviceptr devPtr; /**< Device pointer */ + CUarray_format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t sizeInBytes; /**< Size in bytes */ + } linear; + struct { + CUdeviceptr devPtr; /**< Device pointer */ + CUarray_format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t width; /**< Width of the array in elements */ + size_t height; /**< Height of the array in elements */ + size_t pitchInBytes; /**< Pitch between two rows in bytes */ + } pitch2D; + struct { + int reserved[32]; + } reserved; + } res; + + unsigned int flags; /**< Flags (must be zero) */ +} CUDA_RESOURCE_DESC_v1; +typedef CUDA_RESOURCE_DESC_v1 CUDA_RESOURCE_DESC; + +/** + * Texture descriptor + */ +typedef struct CUDA_TEXTURE_DESC_st { + CUaddress_mode addressMode[3]; /**< Address modes */ + CUfilter_mode filterMode; /**< Filter mode */ + unsigned int flags; /**< Flags */ + unsigned int maxAnisotropy; /**< Maximum anisotropy ratio */ + CUfilter_mode mipmapFilterMode; /**< Mipmap filter mode */ + float mipmapLevelBias; /**< Mipmap level bias */ + float minMipmapLevelClamp; /**< Mipmap minimum level clamp */ + float maxMipmapLevelClamp; /**< Mipmap maximum level clamp */ + float borderColor[4]; /**< Border Color */ + int reserved[12]; +} CUDA_TEXTURE_DESC_v1; +typedef CUDA_TEXTURE_DESC_v1 CUDA_TEXTURE_DESC; + +/** + * Resource view format + */ +typedef enum CUresourceViewFormat_enum { + CU_RES_VIEW_FORMAT_NONE = + 0x00, /**< No resource view format (use underlying resource format) */ + CU_RES_VIEW_FORMAT_UINT_1X8 = 0x01, /**< 1 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_2X8 = 0x02, /**< 2 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X8 = 0x03, /**< 4 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X8 = 0x04, /**< 1 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X8 = 0x05, /**< 2 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_4X8 = 0x06, /**< 4 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_1X16 = + 0x07, /**< 1 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_2X16 = + 0x08, /**< 2 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X16 = + 0x09, /**< 4 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X16 = 0x0a, /**< 1 channel signed 16-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X16 = 0x0b, /**< 2 channel signed 16-bit integers */ + CU_RES_VIEW_FORMAT_SINT_4X16 = 0x0c, /**< 4 channel signed 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_1X32 = + 0x0d, /**< 1 channel unsigned 32-bit integers */ 
+ CU_RES_VIEW_FORMAT_UINT_2X32 = + 0x0e, /**< 2 channel unsigned 32-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X32 = + 0x0f, /**< 4 channel unsigned 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X32 = 0x10, /**< 1 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X32 = 0x11, /**< 2 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_4X32 = 0x12, /**< 4 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_FLOAT_1X16 = 0x13, /**< 1 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_2X16 = 0x14, /**< 2 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_4X16 = 0x15, /**< 4 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_1X32 = 0x16, /**< 1 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_2X32 = 0x17, /**< 2 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_4X32 = 0x18, /**< 4 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC1 = 0x19, /**< Block compressed 1 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC2 = 0x1a, /**< Block compressed 2 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC3 = 0x1b, /**< Block compressed 3 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC4 = 0x1c, /**< Block compressed 4 unsigned */ + CU_RES_VIEW_FORMAT_SIGNED_BC4 = 0x1d, /**< Block compressed 4 signed */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC5 = 0x1e, /**< Block compressed 5 unsigned */ + CU_RES_VIEW_FORMAT_SIGNED_BC5 = 0x1f, /**< Block compressed 5 signed */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC6H = + 0x20, /**< Block compressed 6 unsigned half-float */ + CU_RES_VIEW_FORMAT_SIGNED_BC6H = + 0x21, /**< Block compressed 6 signed half-float */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC7 = 0x22 /**< Block compressed 7 */ +} CUresourceViewFormat; + +/** + * Resource view descriptor + */ +typedef struct CUDA_RESOURCE_VIEW_DESC_st { + CUresourceViewFormat format; /**< Resource view format */ + size_t width; /**< Width of the resource view */ + size_t height; /**< Height of the resource view */ + size_t depth; /**< Depth of the resource view */ + unsigned int firstMipmapLevel; /**< First defined mipmap level */ + unsigned int lastMipmapLevel; /**< Last defined mipmap level */ + unsigned int firstLayer; /**< First layer index */ + unsigned int lastLayer; /**< Last layer index */ + unsigned int reserved[16]; +} CUDA_RESOURCE_VIEW_DESC_v1; +typedef CUDA_RESOURCE_VIEW_DESC_v1 CUDA_RESOURCE_VIEW_DESC; + +/** + * Size of tensor map descriptor + */ +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +/** + * Tensor map descriptor. Requires compiler support for aligning to 64 bytes. 
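 *
 * An encoding sketch (illustrative only) using ::cuTensorMapEncodeTiled for a
 * 2-D row-major fp16 tensor; \c gmemPtr, \c M and \c N are placeholders and
 * error checks are omitted:
 * \code
 *   CUtensorMap map;
 *   cuuint64_t dims[2]       = {N, M};    // fastest-varying dimension first
 *   cuuint64_t strides[1]    = {N * 2};   // row stride in bytes (rank-1 entries)
 *   cuuint32_t box[2]        = {64, 64};  // shared-memory tile shape
 *   cuuint32_t elemStride[2] = {1, 1};
 *   cuTensorMapEncodeTiled(&map, CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 2, gmemPtr,
 *                          dims, strides, box, elemStride,
 *                          CU_TENSOR_MAP_INTERLEAVE_NONE,
 *                          CU_TENSOR_MAP_SWIZZLE_128B,
 *                          CU_TENSOR_MAP_L2_PROMOTION_NONE,
 *                          CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
 * \endcode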
+ */ +typedef struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; + +/** + * Tensor map data type + */ +typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, + CU_TENSOR_MAP_DATA_TYPE_UINT16, + CU_TENSOR_MAP_DATA_TYPE_UINT32, + CU_TENSOR_MAP_DATA_TYPE_INT32, + CU_TENSOR_MAP_DATA_TYPE_UINT64, + CU_TENSOR_MAP_DATA_TYPE_INT64, + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ +} CUtensorMapDataType; + +/** + * Tensor map interleave layout type + */ +typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B +} CUtensorMapInterleave; + +/** + * Tensor map swizzling mode of shared memory banks + */ +typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B, +} CUtensorMapSwizzle; + +/** + * Tensor map L2 promotion type + */ +typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B +} CUtensorMapL2promotion; + +/** + * Tensor map out-of-bounds fill type + */ +typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA +} CUtensorMapFloatOOBfill; + +/** + * GPU Direct v3 tokens + */ +typedef struct CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_st { + unsigned long long p2pToken; + unsigned int vaSpaceToken; +} CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1; +typedef CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1 CUDA_POINTER_ATTRIBUTE_P2P_TOKENS; + +/** + * Access flags that specify the level of access the current context's device + * has on the memory referenced. + */ +typedef enum CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS_enum { + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_NONE = + 0x0, /**< No access, meaning the device cannot access this memory at all, + thus must be staged through accessible memory in order to complete + certain operations */ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READ = + 0x1, /**< Read-only access, meaning writes to this memory are considered + invalid accesses and thus return error in that case. 
*/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READWRITE = + 0x3 /**< Read-write access, the device has full read-write access to the + memory */ +} CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS; + +/** + * Kernel launch parameters + */ +typedef struct CUDA_LAUNCH_PARAMS_st { + CUfunction function; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block + in bytes */ + CUstream hStream; /**< Stream identifier */ + void **kernelParams; /**< Array of pointers to kernel parameters */ +} CUDA_LAUNCH_PARAMS_v1; +typedef CUDA_LAUNCH_PARAMS_v1 CUDA_LAUNCH_PARAMS; + +/** + * External memory handle types + */ +typedef enum CUexternalMemoryHandleType_enum { + /** + * Handle is an opaque file descriptor + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + /** + * Handle is an opaque shared NT handle + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + /** + * Handle is an opaque, globally shared handle + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + /** + * Handle is a D3D12 heap object + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + /** + * Handle is a D3D12 committed resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + /** + * Handle is a shared NT handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + /** + * Handle is a globally shared handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + /** + * Handle is an NvSciBuf object + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 +} CUexternalMemoryHandleType; + +/** + * Indicates that the external memory object is a dedicated resource + */ +#define CUDA_EXTERNAL_MEMORY_DEDICATED 0x1 + +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS + * contains this flag, it indicates that signaling an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, which otherwise are performed by + * default to ensure data coherency with other importers of the same NvSciBuf + * memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC 0x01 + +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS + * contains this flag, it indicates that waiting on an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, which otherwise are performed by + * default to ensure data coherency with other importers of the same NvSciBuf + * memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC 0x02 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs signaler specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. 
+ */ +#define CUDA_NVSCISYNC_ATTR_SIGNAL 0x1 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs waiter specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. + */ +#define CUDA_NVSCISYNC_ATTR_WAIT 0x2 +/** + * External memory handle descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { + /** + * Type of the handle + */ + CUexternalMemoryHandleType type; + union { + /** + * File descriptor referencing the memory object. Valid + * when type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD + */ + int fd; + /** + * Win32 handle referencing the semaphore object. Valid when + * type is one of the following: + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT + * Exactly one of 'handle' and 'name' must be non-NULL. If + * type is one of the following: + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT + * then 'name' must be NULL. + */ + struct { + /** + * Valid NT handle. Must be NULL if 'name' is non-NULL + */ + void *handle; + /** + * Name of a valid memory object. + * Must be NULL if 'handle' is non-NULL. + */ + const void *name; + } win32; + /** + * A handle representing an NvSciBuf Object. Valid when type + * is ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF + */ + const void *nvSciBufObject; + } handle; + /** + * Size of the memory allocation + */ + unsigned long long size; + /** + * Flags must either be zero or ::CUDA_EXTERNAL_MEMORY_DEDICATED + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1 CUDA_EXTERNAL_MEMORY_HANDLE_DESC; + +/** + * External memory buffer descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_BUFFER_DESC_st { + /** + * Offset into the memory object where the buffer's base is + */ + unsigned long long offset; + /** + * Size of the buffer + */ + unsigned long long size; + /** + * Flags reserved for future use. Must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1 CUDA_EXTERNAL_MEMORY_BUFFER_DESC; + +/** + * External memory mipmap descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_st { + /** + * Offset into the memory object where the base level of the + * mipmap chain is. 
+ */ + unsigned long long offset; + /** + * Format, dimension and type of base level of the mipmap chain + */ + CUDA_ARRAY3D_DESCRIPTOR arrayDesc; + /** + * Total number of levels in the mipmap chain + */ + unsigned int numLevels; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1 + CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; + +/** + * External semaphore handle types + */ +typedef enum CUexternalSemaphoreHandleType_enum { + /** + * Handle is an opaque file descriptor + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + /** + * Handle is an opaque shared NT handle + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + /** + * Handle is an opaque, globally shared handle + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + /** + * Handle is a shared NT handle referencing a D3D12 fence object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + /** + * Handle is a shared NT handle referencing a D3D11 fence object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + /** + * Opaque handle to NvSciSync Object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + /** + * Handle is a shared NT handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + /** + * Handle is a globally shared handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + /** + * Handle is an opaque file descriptor referencing a timeline semaphore + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + /** + * Handle is an opaque shared NT handle referencing a timeline semaphore + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 +} CUexternalSemaphoreHandleType; + +/** + * External semaphore handle descriptor + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { + /** + * Type of the handle + */ + CUexternalSemaphoreHandleType type; + union { + /** + * File descriptor referencing the semaphore object. Valid + * when type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD + */ + int fd; + /** + * Win32 handle referencing the semaphore object. Valid when + * type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * Exactly one of 'handle' and 'name' must be non-NULL. If + * type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then 'name' must be NULL. + */ + struct { + /** + * Valid NT handle. Must be NULL if 'name' is non-NULL + */ + void *handle; + /** + * Name of a valid synchronization primitive. + * Must be NULL if 'handle' is non-NULL. + */ + const void *name; + } win32; + /** + * Valid NvSciSyncObj. Must be non NULL + */ + const void *nvSciSyncObj; + } handle; + /** + * Flags reserved for the future. Must be zero. 
+ */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1 + CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; + +/** + * External semaphore signal parameters + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_st { + struct { + /** + * Parameters for fence objects + */ + struct { + /** + * Value of fence to be signaled + */ + unsigned long long value; + } fence; + union { + /** + * Pointer to NvSciSyncFence. Valid if ::CUexternalSemaphoreHandleType + * is of type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. + */ + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to release the mutex with + */ + unsigned long long key; + } keyedMutex; + unsigned int reserved[12]; + } params; + /** + * Only when ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS is used to + * signal a ::CUexternalSemaphore of type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, the valid flag is + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC which indicates + * that while signaling the ::CUexternalSemaphore, no memory synchronization + * operations should be performed for any external memory object imported + * as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. + * For all other types of ::CUexternalSemaphore, flags must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1 + CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS; + +/** + * External semaphore wait parameters + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { + struct { + /** + * Parameters for fence objects + */ + struct { + /** + * Value of fence to be waited on + */ + unsigned long long value; + } fence; + /** + * Pointer to NvSciSyncFence. Valid if CUexternalSemaphoreHandleType + * is of type CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. + */ + union { + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to acquire the mutex with + */ + unsigned long long key; + /** + * Timeout in milliseconds to wait to acquire the mutex + */ + unsigned int timeoutMs; + } keyedMutex; + unsigned int reserved[10]; + } params; + /** + * Only when ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS is used to wait on + * a ::CUexternalSemaphore of type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, the valid flag is + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC which indicates that + * while waiting for the ::CUexternalSemaphore, no memory synchronization + * operations should be performed for any external memory object imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. For all other types of + * ::CUexternalSemaphore, flags must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1 + CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS; + +/** + * Semaphore signal node parameters + */ +typedef struct CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_st { + CUexternalSemaphore *extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS + *paramsArray; /**< Array of external semaphore signal parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in + extSemArray and paramsArray. 
*/ +} CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1 CUDA_EXT_SEM_SIGNAL_NODE_PARAMS; + +/** + * Semaphore signal node parameters + */ +typedef struct CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v2_st { + CUexternalSemaphore *extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS + *paramsArray; /**< Array of external semaphore signal parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in + extSemArray and paramsArray. */ +} CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v2; + +/** + * Semaphore wait node parameters + */ +typedef struct CUDA_EXT_SEM_WAIT_NODE_PARAMS_st { + CUexternalSemaphore *extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS + *paramsArray; /**< Array of external semaphore wait parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in + extSemArray and paramsArray. */ +} CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1 CUDA_EXT_SEM_WAIT_NODE_PARAMS; + +/** + * Semaphore wait node parameters + */ +typedef struct CUDA_EXT_SEM_WAIT_NODE_PARAMS_v2_st { + CUexternalSemaphore *extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS + *paramsArray; /**< Array of external semaphore wait parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in + extSemArray and paramsArray. */ +} CUDA_EXT_SEM_WAIT_NODE_PARAMS_v2; + +typedef unsigned long long CUmemGenericAllocationHandle_v1; +typedef CUmemGenericAllocationHandle_v1 CUmemGenericAllocationHandle; + +/** + * Flags for specifying particular handle types + */ +typedef enum CUmemAllocationHandleType_enum { + CU_MEM_HANDLE_TYPE_NONE = 0x0, /**< Does not allow any export mechanism. > */ + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = + 0x1, /**< Allows a file descriptor to be used for exporting. Permitted + only on POSIX systems. (int) */ + CU_MEM_HANDLE_TYPE_WIN32 = + 0x2, /**< Allows a Win32 NT handle to be used for exporting. (HANDLE) */ + CU_MEM_HANDLE_TYPE_WIN32_KMT = 0x4, /**< Allows a Win32 KMT handle to be used + for exporting. (D3DKMT_HANDLE) */ + CU_MEM_HANDLE_TYPE_FABRIC = 0x8, /**< Allows a fabric handle to be used for + exporting. (CUmemFabricHandle)*/ + CU_MEM_HANDLE_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationHandleType; + +/** + * Specifies the memory protection flags for mapping. 
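+ *
+ * These flags are consumed through ::CUmemAccessDesc by the virtual memory
+ * management APIs. A minimal sketch of the usual flow, assuming device
+ * ordinal 0 and a single-granule allocation (illustrative only; error
+ * checking omitted):
+ *
+ * \code
+ * CUmemAllocationProp prop = {0};
+ * prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
+ * prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ * prop.location.id   = 0;                                  // device ordinal
+ *
+ * size_t gran = 0;
+ * cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+ * size_t size = gran;                                      // one granule, for illustration
+ *
+ * CUmemGenericAllocationHandle handle;
+ * CUdeviceptr va = 0;
+ * cuMemCreate(&handle, size, &prop, 0);                    // physical allocation
+ * cuMemAddressReserve(&va, size, 0, 0, 0);                 // virtual address range
+ * cuMemMap(va, size, 0, handle, 0);                        // map physical into virtual
+ *
+ * CUmemAccessDesc access = {0};
+ * access.location = prop.location;
+ * access.flags    = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+ * cuMemSetAccess(va, size, &access, 1);                    // enable device access
+ * \endcode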
+ */ +typedef enum CUmemAccess_flags_enum { + CU_MEM_ACCESS_FLAGS_PROT_NONE = + 0x0, /**< Default, make the address range not accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READ = + 0x1, /**< Make the address range read accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READWRITE = + 0x3, /**< Make the address range read-write accessible */ + CU_MEM_ACCESS_FLAGS_PROT_MAX = 0x7FFFFFFF +} CUmemAccess_flags; + +/** + * Specifies the type of location + */ +typedef enum CUmemLocationType_enum { + CU_MEM_LOCATION_TYPE_INVALID = 0x0, + CU_MEM_LOCATION_TYPE_DEVICE = + 0x1, /**< Location is a device location, thus id is a device ordinal */ + CU_MEM_LOCATION_TYPE_HOST = 0x2, /**< Location is host, id is ignored */ + CU_MEM_LOCATION_TYPE_HOST_NUMA = + 0x3, /**< Location is a host NUMA node, thus id is a host NUMA node id */ + CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT = + 0x4, /**< Location is a host NUMA node of the current thread, id is + ignored */ + CU_MEM_LOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemLocationType; + +/** + * Defines the allocation types available + */ +typedef enum CUmemAllocationType_enum { + CU_MEM_ALLOCATION_TYPE_INVALID = 0x0, + + /** This allocation type is 'pinned', i.e. cannot migrate from its current + * location while the application is actively using it + */ + CU_MEM_ALLOCATION_TYPE_PINNED = 0x1, + CU_MEM_ALLOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationType; + +/** + * Flag for requesting different optimal and required granularities for an + * allocation. + */ +typedef enum CUmemAllocationGranularity_flags_enum { + CU_MEM_ALLOC_GRANULARITY_MINIMUM = + 0x0, /**< Minimum required granularity for allocation */ + CU_MEM_ALLOC_GRANULARITY_RECOMMENDED = + 0x1 /**< Recommended granularity for allocation for best performance */ +} CUmemAllocationGranularity_flags; + +/** + * Specifies the handle type for address range + */ +typedef enum CUmemRangeHandleType_enum { + CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD = 0x1, + CU_MEM_RANGE_HANDLE_TYPE_MAX = 0x7FFFFFFF +} CUmemRangeHandleType; + +/** + * Sparse subresource types + */ +typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 +} CUarraySparseSubresourceType; + +/** + * Memory operation types + */ +typedef enum CUmemOperationType_enum { + CU_MEM_OPERATION_TYPE_MAP = 1, + CU_MEM_OPERATION_TYPE_UNMAP = 2 +} CUmemOperationType; + +/** + * Memory handle types + */ +typedef enum CUmemHandleType_enum { + CU_MEM_HANDLE_TYPE_GENERIC = 0 +} CUmemHandleType; + +/** + * Specifies the CUDA array or CUDA mipmapped array memory mapping information + */ +typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; /**< Resource type */ + + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; /**< Sparse subresource type */ + + union { + struct { + unsigned int level; /**< For CUDA mipmapped arrays must a valid mipmap + level. For CUDA arrays must be zero */ + unsigned int layer; /**< For CUDA layered arrays must be a valid layer + index. 
Otherwise, must be zero */ + unsigned int offsetX; /**< Starting X offset in elements */ + unsigned int offsetY; /**< Starting Y offset in elements */ + unsigned int offsetZ; /**< Starting Z offset in elements */ + unsigned int extentWidth; /**< Width in elements */ + unsigned int extentHeight; /**< Height in elements */ + unsigned int extentDepth; /**< Depth in elements */ + } sparseLevel; + struct { + unsigned int layer; /**< For CUDA layered arrays must be a valid layer + index. Otherwise, must be zero */ + unsigned long long offset; /**< Offset within mip tail */ + unsigned long long size; /**< Extent in bytes */ + } miptail; + } subresource; + + CUmemOperationType memOperationType; /**< Memory operation type */ + CUmemHandleType memHandleType; /**< Memory handle type */ + + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; /**< Offset within the memory */ + unsigned int deviceBitMask; /**< Device ordinal bit mask */ + unsigned int flags; /**< flags for future use, must be zero now. */ + unsigned int reserved[2]; /**< Reserved for future use, must be zero now. */ +} CUarrayMapInfo_v1; +typedef CUarrayMapInfo_v1 CUarrayMapInfo; + +/** + * Specifies a memory location. + */ +typedef struct CUmemLocation_st { + CUmemLocationType type; /**< Specifies the location type, which modifies the + meaning of id. */ + int id; /**< identifier for a given this location's ::CUmemLocationType. */ +} CUmemLocation_v1; +typedef CUmemLocation_v1 CUmemLocation; + +/** + * Specifies compression attribute for an allocation. + */ +typedef enum CUmemAllocationCompType_enum { + CU_MEM_ALLOCATION_COMP_NONE = 0x0, /**< Allocating non-compressible memory */ + CU_MEM_ALLOCATION_COMP_GENERIC = 0x1 /**< Allocating compressible memory */ +} CUmemAllocationCompType; + +/** + * This flag if set indicates that the memory will be used as a tile pool. + */ +#define CU_MEM_CREATE_USAGE_TILE_POOL 0x1 + +/** + * Specifies the allocation properties for a allocation. + */ +typedef struct CUmemAllocationProp_st { + /** Allocation type */ + CUmemAllocationType type; + /** requested ::CUmemAllocationHandleType */ + CUmemAllocationHandleType requestedHandleTypes; + /** Location of allocation */ + CUmemLocation location; + /** + * Windows-specific POBJECT_ATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object attributes structure + * includes security attributes that define + * the scope of which exported allocations may be transferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32HandleMetaData; + struct { + /** + * Allocation hint for requesting compressible memory. + * On devices that support Compute Data Compression, compressible + * memory can be used to accelerate accesses to data with unstructured + * sparsity and other compressible data patterns. Applications are + * expected to query allocation property of the handle obtained with + * ::cuMemCreate using ::cuMemGetAllocationPropertiesFromHandle to + * validate if the obtained allocation is compressible or not. Note that + * compressed memory may not be mappable on all devices. 
+ */ + unsigned char compressionType; + unsigned char gpuDirectRDMACapable; + /** Bitmask indicating intended usage for this allocation */ + unsigned short usage; + unsigned char reserved[4]; + } allocFlags; +} CUmemAllocationProp_v1; +typedef CUmemAllocationProp_v1 CUmemAllocationProp; + +/** + * Flags for querying different granularities for a multicast object + */ +typedef enum CUmulticastGranularity_flags_enum { + CU_MULTICAST_GRANULARITY_MINIMUM = 0x0, /**< Minimum required granularity */ + CU_MULTICAST_GRANULARITY_RECOMMENDED = + 0x1 /**< Recommended granularity for best performance */ +} CUmulticastGranularity_flags; + +/** + * Specifies the properties for a multicast object. + */ +typedef struct CUmulticastObjectProp_st { + /** + * The number of devices in the multicast team that will bind memory to this + * object + */ + unsigned int numDevices; + /** + * The maximum amount of memory that can be bound to this multicast object + * per device + */ + size_t size; + /** + * Bitmask of exportable handle types (see ::CUmemAllocationHandleType) for + * this object + */ + unsigned long long handleTypes; + /** + * Flags for future use, must be zero now + */ + unsigned long long flags; +} CUmulticastObjectProp_v1; +typedef CUmulticastObjectProp_v1 CUmulticastObjectProp; + +/** + * Memory access descriptor + */ +typedef struct CUmemAccessDesc_st { + CUmemLocation location; /**< Location on which the request is to change it's + accessibility */ + CUmemAccess_flags + flags; /**< ::CUmemProt accessibility flags to set on the request */ +} CUmemAccessDesc_v1; +typedef CUmemAccessDesc_v1 CUmemAccessDesc; + +/** + * CUDA Graph Update error types + */ +typedef enum CUgraphExecUpdateResult_enum { + CU_GRAPH_EXEC_UPDATE_SUCCESS = 0x0, /**< The update succeeded */ + CU_GRAPH_EXEC_UPDATE_ERROR = + 0x1, /**< The update failed for an unexpected reason which is described in + the return value of the function */ + CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED = + 0x2, /**< The update failed because the topology changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED = + 0x3, /**< The update failed because a node type changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED = + 0x4, /**< The update failed because the function of a kernel node changed + (CUDA driver < 11.2) */ + CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED = + 0x5, /**< The update failed because the parameters changed in a way that + is not supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED = + 0x6, /**< The update failed because something about the node is not + supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE = + 0x7, /**< The update failed because the function of a kernel node changed + in an unsupported way */ + CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED = + 0x8 /**< The update failed because the node attributes changed in a way + that is not supported */ +} CUgraphExecUpdateResult; + +/** + * Result information returned by cuGraphExecUpdate + */ +typedef struct CUgraphExecUpdateResultInfo_st { + /** + * Gives more specific detail when a cuda graph update fails. + */ + CUgraphExecUpdateResult result; + + /** + * The "to node" of the error edge when the topologies do not match. + * The error node when the error is associated with a specific node. + * NULL when the error is generic. + */ + CUgraphNode errorNode; + + /** + * The from node of error edge when the topologies do not match. Otherwise + * NULL. 
+ */ + CUgraphNode errorFromNode; +} CUgraphExecUpdateResultInfo_v1; +typedef CUgraphExecUpdateResultInfo_v1 CUgraphExecUpdateResultInfo; + +/** + * CUDA memory pool attributes + */ +typedef enum CUmemPool_attribute_enum { + /** + * (value type = int) + * Allow cuMemAllocAsync to use memory asynchronously freed + * in another streams as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES = 1, + + /** + * (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, + + /** + * (value type = int) + * Allow cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by cuFreeAsync (default enabled). + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES, + + /** + * (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + */ + CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, + + /** + * (value type = cuuint64_t) + * Amount of backing memory currently allocated for the mempool. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by the application. + */ + CU_MEMPOOL_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that was in use by the + * application since the last time it was reset. High watermark can only be + * reset to zero. + */ + CU_MEMPOOL_ATTR_USED_MEM_HIGH +} CUmemPool_attribute; + +/** + * Specifies the properties of allocations made from the pool. + */ +typedef struct CUmemPoolProps_st { + CUmemAllocationType + allocType; /**< Allocation type. Currently must be specified as + CU_MEM_ALLOCATION_TYPE_PINNED */ + CUmemAllocationHandleType + handleTypes; /**< Handle types that will be supported by allocations from + the pool. */ + CUmemLocation location; /**< Location where allocations should reside. */ + /** + * Windows-specific LPSECURITYATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines + * the scope of which exported allocations may be transferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32SecurityAttributes; + size_t maxSize; /**< Maximum pool size. When set to 0, defaults to a system + dependent value. 
*/ + unsigned char reserved[56]; /**< reserved for future use, must be 0 */ +} CUmemPoolProps_v1; +typedef CUmemPoolProps_v1 CUmemPoolProps; + +/** + * Opaque data for exporting a pool allocation + */ +typedef struct CUmemPoolPtrExportData_st { + unsigned char reserved[64]; +} CUmemPoolPtrExportData_v1; +typedef CUmemPoolPtrExportData_v1 CUmemPoolPtrExportData; + +/** + * Memory allocation node parameters + */ +typedef struct CUDA_MEM_ALLOC_NODE_PARAMS_v1_st { + /** + * in: location where the allocation should reside (specified in ::location). + * ::handleTypes must be ::CU_MEM_HANDLE_TYPE_NONE. IPC is not supported. + */ + CUmemPoolProps poolProps; + const CUmemAccessDesc + *accessDescs; /**< in: array of memory access descriptors. Used to + describe peer GPU access */ + size_t accessDescCount; /**< in: number of memory access descriptors. Must + not exceed the number of GPUs. */ + size_t bytesize; /**< in: size in bytes of the requested allocation */ + CUdeviceptr dptr; /**< out: address of the allocation returned by CUDA */ +} CUDA_MEM_ALLOC_NODE_PARAMS_v1; +typedef CUDA_MEM_ALLOC_NODE_PARAMS_v1 CUDA_MEM_ALLOC_NODE_PARAMS; + +/** + * Memory allocation node parameters + */ +typedef struct CUDA_MEM_ALLOC_NODE_PARAMS_v2_st { + /** + * in: location where the allocation should reside (specified in ::location). + * ::handleTypes must be ::CU_MEM_HANDLE_TYPE_NONE. IPC is not supported. + */ + CUmemPoolProps poolProps; + const CUmemAccessDesc + *accessDescs; /**< in: array of memory access descriptors. Used to + describe peer GPU access */ + size_t accessDescCount; /**< in: number of memory access descriptors. Must + not exceed the number of GPUs. */ + size_t bytesize; /**< in: size in bytes of the requested allocation */ + CUdeviceptr dptr; /**< out: address of the allocation returned by CUDA */ +} CUDA_MEM_ALLOC_NODE_PARAMS_v2; + +/** + * Memory free node parameters + */ +typedef struct CUDA_MEM_FREE_NODE_PARAMS_st { + CUdeviceptr dptr; /**< in: the pointer to free */ +} CUDA_MEM_FREE_NODE_PARAMS; + +typedef enum CUgraphMem_attribute_enum { + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently associated with graphs + */ + CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH +} CUgraphMem_attribute; + +/** + * Child graph node parameters + */ +typedef struct CUDA_CHILD_GRAPH_NODE_PARAMS_st { + CUgraph graph; /**< The child graph to clone into the node for node creation, + or a handle to the graph owned by the node for node query */ +} CUDA_CHILD_GRAPH_NODE_PARAMS; + +/** + * Event record node parameters + */ +typedef struct CUDA_EVENT_RECORD_NODE_PARAMS_st { + CUevent event; /**< The event to record when the node executes */ +} CUDA_EVENT_RECORD_NODE_PARAMS; + +/** + * Event wait node parameters + */ +typedef struct CUDA_EVENT_WAIT_NODE_PARAMS_st { + CUevent event; /**< The event to wait on from the node */ +} CUDA_EVENT_WAIT_NODE_PARAMS; + +/** + * Graph node parameters. 
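+ * A node description is passed to ::cuGraphAddNode with \p type and the
+ * matching union member filled in; unused bytes must be zero. Minimal,
+ * illustrative sketch (assumes an existing graph \c hGraph and event
+ * \c hEvent; error checking omitted):
+ *
+ * \code
+ * CUgraphNodeParams nodeParams;
+ * memset(&nodeParams, 0, sizeof(nodeParams));           // reserved fields must be zero
+ * nodeParams.type              = CU_GRAPH_NODE_TYPE_EVENT_RECORD;
+ * nodeParams.eventRecord.event = hEvent;
+ *
+ * CUgraphNode node;
+ * cuGraphAddNode(&node, hGraph, NULL, 0, &nodeParams);  // no dependencies
+ * \endcode
+ *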
See ::cuGraphAddNode. + */ +typedef struct CUgraphNodeParams_st { + CUgraphNodeType type; /**< Type of the node */ + int reserved0[3]; /**< Reserved. Must be zero. */ + + union { + long long reserved1[29]; /**< Padding. Unused bytes must be zero. */ + CUDA_KERNEL_NODE_PARAMS_v3 kernel; /**< Kernel node parameters. */ + CUDA_MEMCPY_NODE_PARAMS memcpy; /**< Memcpy node parameters. */ + CUDA_MEMSET_NODE_PARAMS_v2 memset; /**< Memset node parameters. */ + CUDA_HOST_NODE_PARAMS_v2 host; /**< Host node parameters. */ + CUDA_CHILD_GRAPH_NODE_PARAMS graph; /**< Child graph node parameters. */ + CUDA_EVENT_WAIT_NODE_PARAMS eventWait; /**< Event wait node parameters. */ + CUDA_EVENT_RECORD_NODE_PARAMS + eventRecord; /**< Event record node parameters. */ + CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v2 + extSemSignal; /**< External semaphore signal node parameters. */ + CUDA_EXT_SEM_WAIT_NODE_PARAMS_v2 + extSemWait; /**< External semaphore wait node parameters. */ + CUDA_MEM_ALLOC_NODE_PARAMS_v2 + alloc; /**< Memory allocation node parameters. */ + CUDA_MEM_FREE_NODE_PARAMS free; /**< Memory free node parameters. */ + CUDA_BATCH_MEM_OP_NODE_PARAMS_v2 memOp; /**< MemOp node parameters. */ + CUDA_CONDITIONAL_NODE_PARAMS + conditional; /**< Conditional node parameters. */ + }; + + long long reserved2; /**< Reserved bytes. Must be zero. */ +} CUgraphNodeParams; + +/** + * If set, each kernel launched as part of + * ::cuLaunchCooperativeKernelMultiDevice only waits for prior work in the + * stream corresponding to that GPU to complete before the kernel begins + * execution. + */ +#define CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_PRE_LAUNCH_SYNC 0x01 + +/** + * If set, any subsequent work pushed in a stream that participated in a call to + * ::cuLaunchCooperativeKernelMultiDevice will only wait for the kernel launched + * on the GPU corresponding to that stream to complete before it begins + * execution. + */ +#define CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_POST_LAUNCH_SYNC 0x02 + +/** + * If set, the CUDA array is a collection of layers, where each layer is either + * a 1D or a 2D array and the Depth member of CUDA_ARRAY3D_DESCRIPTOR specifies + * the number of layers, not the depth of a 3D array. + */ +#define CUDA_ARRAY3D_LAYERED 0x01 + +/** + * Deprecated, use CUDA_ARRAY3D_LAYERED + */ +#define CUDA_ARRAY3D_2DARRAY 0x01 + +/** + * This flag must be set in order to bind a surface reference + * to the CUDA array + */ +#define CUDA_ARRAY3D_SURFACE_LDST 0x02 + +/** + * If set, the CUDA array is a collection of six 2D arrays, representing faces + * of a cube. The width of such a CUDA array must be equal to its height, and + * Depth must be six. If ::CUDA_ARRAY3D_LAYERED flag is also set, then the CUDA + * array is a collection of cubemaps and Depth must be a multiple of six. + */ +#define CUDA_ARRAY3D_CUBEMAP 0x04 + +/** + * This flag must be set in order to perform texture gather operations + * on a CUDA array. + */ +#define CUDA_ARRAY3D_TEXTURE_GATHER 0x08 + +/** + * This flag if set indicates that the CUDA + * array is a DEPTH_TEXTURE. 
+ */ +#define CUDA_ARRAY3D_DEPTH_TEXTURE 0x10 + +/** + * This flag indicates that the CUDA array may be bound as a color target + * in an external graphics API + */ +#define CUDA_ARRAY3D_COLOR_ATTACHMENT 0x20 + +/** + * This flag if set indicates that the CUDA array or CUDA mipmapped array + * is a sparse CUDA array or CUDA mipmapped array respectively + */ +#define CUDA_ARRAY3D_SPARSE 0x40 + +/** + * This flag if set indicates that the CUDA array or CUDA mipmapped array + * will allow deferred memory mapping + */ +#define CUDA_ARRAY3D_DEFERRED_MAPPING 0x80 + +/** + * Override the texref format with a format inferred from the array. + * Flag for ::cuTexRefSetArray() + */ +#define CU_TRSA_OVERRIDE_FORMAT 0x01 + +/** + * Read the texture as integers rather than promoting the values to floats + * in the range [0,1]. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_READ_AS_INTEGER 0x01 + +/** + * Use normalized texture coordinates in the range [0,1) instead of [0,dim). + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_NORMALIZED_COORDINATES 0x02 + +/** + * Perform sRGB->linear conversion during texture read. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_SRGB 0x10 + +/** + * Disable any trilinear filtering optimizations. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION 0x20 + +/** + * Enable seamless cube map filtering. + * Flag for ::cuTexObjectCreate() + */ +#define CU_TRSF_SEAMLESS_CUBEMAP 0x40 + +/** + * C++ compile time constant for CU_LAUNCH_PARAM_END + */ +#define CU_LAUNCH_PARAM_END_AS_INT 0x00 + +/** + * End of array terminator for the \p extra parameter to + * ::cuLaunchKernel + */ +#define CU_LAUNCH_PARAM_END ((void *)CU_LAUNCH_PARAM_END_AS_INT) + +/** + * C++ compile time constant for CU_LAUNCH_PARAM_BUFFER_POINTER + */ +#define CU_LAUNCH_PARAM_BUFFER_POINTER_AS_INT 0x01 + +/** + * Indicator that the next value in the \p extra parameter to + * ::cuLaunchKernel will be a pointer to a buffer containing all kernel + * parameters used for launching kernel \p f. This buffer needs to + * honor all alignment/padding requirements of the individual parameters. + * If ::CU_LAUNCH_PARAM_BUFFER_SIZE is not also specified in the + * \p extra array, then ::CU_LAUNCH_PARAM_BUFFER_POINTER will have no + * effect. + */ +#define CU_LAUNCH_PARAM_BUFFER_POINTER \ + ((void *)CU_LAUNCH_PARAM_BUFFER_POINTER_AS_INT) + +/** + * C++ compile time constant for CU_LAUNCH_PARAM_BUFFER_SIZE + */ +#define CU_LAUNCH_PARAM_BUFFER_SIZE_AS_INT 0x02 + +/** + * Indicator that the next value in the \p extra parameter to + * ::cuLaunchKernel will be a pointer to a size_t which contains the + * size of the buffer specified with ::CU_LAUNCH_PARAM_BUFFER_POINTER. + * It is required that ::CU_LAUNCH_PARAM_BUFFER_POINTER also be specified + * in the \p extra array if the value associated with + * ::CU_LAUNCH_PARAM_BUFFER_SIZE is not zero. + */ +#define CU_LAUNCH_PARAM_BUFFER_SIZE ((void *)CU_LAUNCH_PARAM_BUFFER_SIZE_AS_INT) + +/** + * For texture references loaded into the module, use default texunit from + * texture reference. 
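+ *
+ * Usage note for the ::CU_LAUNCH_PARAM_BUFFER_POINTER /
+ * ::CU_LAUNCH_PARAM_BUFFER_SIZE / ::CU_LAUNCH_PARAM_END markers defined
+ * above: they are normally assembled into the \p extra argument of
+ * ::cuLaunchKernel. Illustrative sketch, assuming \c argBuffer and
+ * \c argBufferSize describe an already-packed kernel argument buffer:
+ *
+ * \code
+ * void *extra[] = {
+ *     CU_LAUNCH_PARAM_BUFFER_POINTER, argBuffer,
+ *     CU_LAUNCH_PARAM_BUFFER_SIZE,    &argBufferSize,
+ *     CU_LAUNCH_PARAM_END
+ * };
+ * cuLaunchKernel(f, gridX, gridY, gridZ, blockX, blockY, blockZ,
+ *                sharedMemBytes, hStream,
+ *                NULL,                                   // kernelParams unused here
+ *                extra);
+ * \endcode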
+ */ +#define CU_PARAM_TR_DEFAULT -1 + +/** + * Device that represents the CPU + */ +#define CU_DEVICE_CPU ((CUdevice) - 1) + +/** + * Device that represents an invalid device + */ +#define CU_DEVICE_INVALID ((CUdevice) - 2) + +/** + * Bitmasks for ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS + */ +typedef enum CUflushGPUDirectRDMAWritesOptions_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_HOST = + 1 << 0, /**< ::cuFlushGPUDirectRDMAWrites() and its CUDA Runtime API + counterpart are supported on the device. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_MEMOPS = + 1 << 1 /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the + ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on + the device. */ +} CUflushGPUDirectRDMAWritesOptions; + +/** + * Platform native ordering for GPUDirect RDMA writes + */ +typedef enum CUGPUDirectRDMAWritesOrdering_enum { + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_NONE = + 0, /**< The device does not natively support ordering of remote writes. + ::cuFlushGPUDirectRDMAWrites() can be leveraged if supported. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_OWNER = + 100, /**< Natively, the device can consistently consume remote writes, + although other CUDA devices may not. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_ALL_DEVICES = + 200 /**< Any CUDA device in the system can consistently consume remote + writes to this device. */ +} CUGPUDirectRDMAWritesOrdering; + +/** + * The scopes for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesScope_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER = + 100, /**< Blocks until remote writes are visible to the CUDA device + context owning the data. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_ALL_DEVICES = + 200 /**< Blocks until remote writes are visible to all CUDA device + contexts. */ +} CUflushGPUDirectRDMAWritesScope; + +/** + * The targets for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesTarget_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX = + 0 /**< Sets the target for ::cuFlushGPUDirectRDMAWrites() to the currently + active CUDA device context. 
*/ +} CUflushGPUDirectRDMAWritesTarget; + +/** + * The additional write options for ::cuGraphDebugDotPrint + */ +typedef enum CUgraphDebugDot_flags_enum { + CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE = + 1 << 0, /**< Output all debug data as if every debug flag is enabled */ + CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES = + 1 << 1, /**< Use CUDA Runtime structures for output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS = + 1 << 2, /**< Adds CUDA_KERNEL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS = + 1 << 3, /**< Adds CUDA_MEMCPY3D values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS = + 1 << 4, /**< Adds CUDA_MEMSET_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS = + 1 << 5, /**< Adds CUDA_HOST_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS = + 1 << 6, /**< Adds CUevent handle from record and wait nodes to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS = + 1 << 7, /**< Adds CUDA_EXT_SEM_SIGNAL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS = + 1 << 8, /**< Adds CUDA_EXT_SEM_WAIT_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES = + 1 << 9, /**< Adds CUkernelNodeAttrValue values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES = + 1 << 10, /**< Adds node handles and every kernel function handle to output + */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS = + 1 << 11, /**< Adds memory alloc node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS = + 1 << 12, /**< Adds memory free node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS = + 1 << 13 /**< Adds batch mem op node parameters to output */ + , + CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO = + 1 << 14 /**< Adds edge numbering information */ + , + CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS = + 1 << 15 /**< Adds conditional node parameters to output */ +} CUgraphDebugDot_flags; + +/** + * Flags for user objects for graphs + */ +typedef enum CUuserObject_flags_enum { + CU_USER_OBJECT_NO_DESTRUCTOR_SYNC = + 1 /**< Indicates the destructor execution is not synchronized by any CUDA + handle. */ +} CUuserObject_flags; + +/** + * Flags for retaining user object references for graphs + */ +typedef enum CUuserObjectRetain_flags_enum { + CU_GRAPH_USER_OBJECT_MOVE = 1 /**< Transfer references from the caller rather + than creating new references. */ +} CUuserObjectRetain_flags; + +/** + * Flags for instantiating a graph + */ +typedef enum CUgraphInstantiate_flags_enum { + CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH = + 1 /**< Automatically free memory allocated in a graph before relaunching. + */ + , + CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD = + 2 /**< Automatically upload the graph after instantiation. Only supported + by + ::cuGraphInstantiateWithParams. The upload will be performed using + the stream provided in \p instantiateParams. */ + , + CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH = + 4 /**< Instantiate the graph to be launchable from the device. This flag + can only be used on platforms which support unified addressing. This + flag cannot be used in conjunction with + CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH. */ + , + CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY = + 8 /**< Run the graph using the per-node priority attributes rather than + the priority of the stream it is launched into. 
*/ +} CUgraphInstantiate_flags; + +typedef enum CUdeviceNumaConfig_enum { + CU_DEVICE_NUMA_CONFIG_NONE = 0, /**< The GPU is not a NUMA node */ + CU_DEVICE_NUMA_CONFIG_NUMA_NODE, /**< The GPU is a NUMA node, + CU_DEVICE_ATTRIBUTE_NUMA_ID contains its + NUMA ID */ +} CUdeviceNumaConfig; + +/** @} */ /* END CUDA_TYPES */ + +#if defined(__GNUC__) +#if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) +#pragma GCC visibility push(default) +#endif +#endif + +#ifdef _WIN32 +#define CUDAAPI __stdcall +#else +#define CUDAAPI +#endif + +/** + * \defgroup CUDA_ERROR Error Handling + * + * ___MANBRIEF___ error handling functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the error handling functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Gets the string description of an error code + * + * Sets \p *pStr to the address of a NULL-terminated string description + * of the error code \p error. + * If the error code is not recognized, ::CUDA_ERROR_INVALID_VALUE + * will be returned and \p *pStr will be set to the NULL address. + * + * \param error - Error code to convert to string + * \param pStr - Address of the string pointer. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::CUresult, + * ::cudaGetErrorString + */ +CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr); + +/** + * \brief Gets the string representation of an error code enum name + * + * Sets \p *pStr to the address of a NULL-terminated string representation + * of the name of the enum error code \p error. + * If the error code is not recognized, ::CUDA_ERROR_INVALID_VALUE + * will be returned and \p *pStr will be set to the NULL address. + * + * \param error - Error code to convert to string + * \param pStr - Address of the string pointer. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::CUresult, + * ::cudaGetErrorName + */ +CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr); + +/** @} */ /* END CUDA_ERROR */ + +/** + * \defgroup CUDA_INITIALIZE Initialization + * + * ___MANBRIEF___ initialization functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the initialization functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Initialize the CUDA driver API + * Initializes the driver API and must be called before any other function from + * the driver API in the current process. Currently, the \p Flags parameter must + * be 0. If ::cuInit() has not been called, any function from the driver API + * will return + * ::CUDA_ERROR_NOT_INITIALIZED. + * + * \param Flags - Initialization flag for CUDA. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_SYSTEM_DRIVER_MISMATCH, + * ::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE + * \notefnerr + */ +CUresult CUDAAPI cuInit(unsigned int Flags); + +/** @} */ /* END CUDA_INITIALIZE */ + +/** + * \defgroup CUDA_VERSION Version Management + * + * ___MANBRIEF___ version management functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the version management functions of the low-level + * CUDA driver application programming interface. 
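+ *
+ * A typical bootstrap sequence initializes the driver, queries the supported
+ * CUDA version, and turns any failure into a readable message via the error
+ * handling functions above. Illustrative sketch (error handling simplified):
+ *
+ * \code
+ * int driverVersion = 0;
+ * CUresult rc = cuInit(0);                      // must precede all other driver calls
+ * if (rc == CUDA_SUCCESS)
+ *     rc = cuDriverGetVersion(&driverVersion);  // e.g. 12040 for CUDA 12.4
+ * if (rc != CUDA_SUCCESS) {
+ *     const char *name = NULL;
+ *     cuGetErrorName(rc, &name);
+ *     fprintf(stderr, "CUDA driver error: %s\n", name ? name : "UNKNOWN");
+ * }
+ * \endcode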
+ * + * @{ + */ + +/** + * \brief Returns the latest CUDA version supported by driver + * + * Returns in \p *driverVersion the version of CUDA supported by + * the driver. The version is returned as + * (1000 × major + 10 × minor). For example, CUDA 9.2 + * would be represented by 9020. + * + * This function automatically returns ::CUDA_ERROR_INVALID_VALUE if + * \p driverVersion is NULL. + * + * \param driverVersion - Returns the CUDA driver version + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cudaDriverGetVersion, + * ::cudaRuntimeGetVersion + */ +CUresult CUDAAPI cuDriverGetVersion(int *driverVersion); + +/** @} */ /* END CUDA_VERSION */ + +/** + * \defgroup CUDA_DEVICE Device Management + * + * ___MANBRIEF___ device management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the device management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns a handle to a compute device + * + * Returns in \p *device a device handle given an ordinal in the range [0, + * ::cuDeviceGetCount()-1]. + * + * \param device - Returned device handle + * \param ordinal - Device number to get handle for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport + */ +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); + +/** + * \brief Returns the number of compute-capable devices + * + * Returns in \p *count the number of devices with compute capability greater + * than or equal to 2.0 that are available for execution. If there is no such + * device, ::cuDeviceGetCount() returns 0. + * + * \param count - Returned number of compute-capable devices + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceCount + */ +CUresult CUDAAPI cuDeviceGetCount(int *count); + +/** + * \brief Returns an identifier string for the device + * + * Returns an ASCII string identifying the device \p dev in the NULL-terminated + * string pointed to by \p name. \p len specifies the maximum length of the + * string that may be returned. 
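+ *
+ * Illustrative sketch of enumerating devices by name (error checking omitted):
+ *
+ * \code
+ * int count = 0;
+ * cuDeviceGetCount(&count);
+ * for (int ordinal = 0; ordinal < count; ++ordinal) {
+ *     CUdevice dev;
+ *     char name[256];
+ *     cuDeviceGet(&dev, ordinal);
+ *     cuDeviceGetName(name, (int)sizeof(name), dev);
+ *     printf("device %d: %s\n", ordinal, name);
+ * }
+ * \endcode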
+ * + * \param name - Returned identifier string for the device + * \param len - Maximum length of string to store in \p name + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceGetCount, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); + +/** + * \brief Return an UUID for the device + * + * Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will + * supplant this version in 12.0, which is retained for minor version + * compatibility. + * + * Returns 16-octets identifying the device \p dev in the structure + * pointed by the \p uuid. + * + * \param uuid - Returned UUID + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetUuid_v2 + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); + +/** + * \brief Return an UUID for the device (11.4+) + * + * Returns 16-octets identifying the device \p dev in the structure + * pointed by the \p uuid. If the device is in MIG mode, returns its + * MIG UUID which uniquely identifies the subscribed MIG compute instance. + * + * \param uuid - Returned UUID + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetUuid_v2(CUuuid *uuid, CUdevice dev); + +/** + * \brief Return an LUID and device node mask for the device + * + * Return identifying information (\p luid and \p deviceNodeMask) to allow + * matching device with graphics APIs. + * + * \param luid - Returned LUID + * \param deviceNodeMask - Returned device node mask + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetLuid(char *luid, unsigned int *deviceNodeMask, + CUdevice dev); + +/** + * \brief Returns the total amount of memory on the device + * + * Returns in \p *bytes the total amount of memory available on the device + * \p dev in bytes. 
+ * + * \param bytes - Returned memory available on device in bytes + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaMemGetInfo + */ +CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); + +/** + * \brief Returns the maximum number of elements allocatable in a 1D linear + * texture for a given texture element size. + * + * Returns in \p maxWidthInElements the maximum number of texture elements + * allocatable in a 1D linear texture for given \p format and \p numChannels. + * + * \param maxWidthInElements - Returned maximum number of texture elements + * allocatable for given \p format and \p numChannels. \param format - Texture + * format. \param numChannels - Number of channels per texture + * element. \param dev - Device handle. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cudaMemGetInfo, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, + CUarray_format format, + unsigned numChannels, + CUdevice dev); + +/** + * \brief Returns information about the device + * + * Returns in \p *pi the integer value of the attribute \p attrib on device + * \p dev. 
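+ * For example (illustrative; error checking omitted):
+ * \code
+ * int smCount = 0;
+ * cuDeviceGetAttribute(&smCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
+ * \endcode
+ *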
The supported attributes are: + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per + * block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of + * shared memory available to a thread block in bytes + * - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for + * __constant__ variables in a CUDA C kernel in bytes + * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads + * - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the + * memory copy functions that involve memory regions allocated through + * ::cuMemAllocPitch() + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width + * for a 1D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum + * mipmapped 1D texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D + * texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width + * for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height + * for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch + * in bytes for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum + * mipmapped 2D texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum + * mipmapped 2D texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D + * texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D + * texture depth + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE: + * Alternate maximum 3D texture width, 0 if no alternate + * maximum 3D texture size is supported + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE: + * Alternate maximum 3D texture height, 0 if no alternate + * maximum 3D texture size is supported + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE: + * Alternate maximum 3D texture depth, 0 if no alternate + * maximum 3D texture size is supported + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH: + * Maximum cubemap texture width or height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH: + * Maximum 1D layered texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS: + * Maximum layers in a 1D layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH: + * Maximum 2D layered texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT: + * Maximum 2D layered texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS: + * Maximum layers in a 2D layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH: + * Maximum cubemap layered 
texture width or height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS: + * Maximum layers in a cubemap layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH: + * Maximum 1D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH: + * Maximum 2D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT: + * Maximum 2D surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH: + * Maximum 3D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT: + * Maximum 3D surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH: + * Maximum 3D surface depth + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH: + * Maximum 1D layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS: + * Maximum layers in a 1D layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH: + * Maximum 2D layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT: + * Maximum 2D layered surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS: + * Maximum layers in a 2D layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH: + * Maximum cubemap surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH: + * Maximum cubemap layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS: + * Maximum layers in a cubemap layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit + * registers available to a thread block + * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz + * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture + * base addresses aligned to ::textureAlign bytes do not need an offset + * applied to texture fetches + * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement + * for 2D texture references bound to pitched memory + * - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy + * memory between host and device while executing a kernel, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on + * the device + * - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit + * for kernels executed on the device, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the + * memory subsystem, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host + * memory into the CUDA address space, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently + * in. Available modes are as follows: + * - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and + * can have multiple CUDA contexts present at a single time. + * - ::CU_COMPUTEMODE_PROHIBITED: Compute-prohibited mode - Device is + * prohibited from creating new CUDA contexts. + * - ::CU_COMPUTEMODE_EXCLUSIVE_PROCESS: Compute-exclusive-process mode - + * Device can have only one context used by a single process at a time. + * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS: 1 if the device supports + * executing multiple kernels within the same context simultaneously, or 0 if + * not. It is not guaranteed that multiple kernels will be resident + * on the device concurrently so this feature should not be relied upon for + * correctness. 
+ * - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the + * device, 0 if error correction is disabled or not supported by the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) + * identifier of the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device + * - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. + * TCC is only available on Tesla hardware running Windows Vista or later + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in + * kilohertz + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in + * bits + * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the + * device doesn't have L2 cache + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident + * threads per multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified + * address space with the host, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability + * version number + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability + * version number + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports + * caching globals in L1 cache, 0 if caching globals in L1 cache is not + * supported by the device + * - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports + * caching locals in L1 cache, 0 if caching locals in L1 cache is not supported + * by the device + * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount + * of shared memory available to a multiprocessor in bytes; this amount is + * shared by all thread blocks simultaneously resident on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of + * 32-bit registers available to a multiprocessor; this number is shared by all + * thread blocks simultaneously resident on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating + * managed memory on this system, 0 if allocating managed memory is not + * supported by the device on this system. + * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, + * 0 if not. + * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID: Unique identifier for a + * group of devices associated with the same board. Devices on the same + * multi-GPU board will share the same identifier. + * - ::CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED: 1 if Link between the + * device and the host supports native atomic operations. + * - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of + * single precision performance (in floating-point operations per second) to + * double precision performance. + * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device supports coherently + * accessing pageable memory without calling cudaHostRegister on it. + * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently + * access managed memory concurrently with the CPU. + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED: Device supports Compute + * Preemption. + * - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can + * access host registered memory at the same virtual address as the CPU. 
+ * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per + * block shared memory size supported on this device. This is the maximum value + * that can be opted into when using the cuFuncSetAttribute() or + * cuKernelSetAttribute() call. For more details see + * ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES + * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device + * accesses pageable memory via the host's page tables. + * - ::CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST: The host can + * directly access managed memory on the device without migration. + * - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports + * virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, + * ::cuMemMap and related APIs + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device + * supports exporting memory to a posix file descriptor with + * ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports + * exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if + * requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device + * supports exporting memory to a Win32 KMT handle with + * ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of + * thread blocks that can reside on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports + * compressible memory allocation via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting + * lines capacity setting in bytes + * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of + * CUaccessPolicyWindow::num_bytes + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device + * supports specifying the GPUDirect RDMA flag with ::cuMemCreate. + * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared + * memory per block reserved by CUDA driver in bytes + * - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse + * CUDA arrays and sparse CUDA mipmapped arrays. + * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports + * using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register + * memory that must be mapped as read-only to the GPU + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the + * ::cuMemAllocAsync and ::cuMemPool family of APIs + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect + * RDMA APIs, like nvidia_p2p_get_pages (see + * https://docs.nvidia.com/cuda/gpudirect-rdma for more information) + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned + * attribute shall be interpreted as a bitmask, where the individual bits are + * described by the ::CUflushGPUDirectRDMAWritesOptions enum + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA + * writes to the device do not need to be flushed for consumers within the scope + * indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for + * the numerical values returned here. 
+ * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle + * types supported with mempool based IPC + * - ::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED: Device + * supports deferred mapping CUDA arrays and CUDA mipmapped arrays. + * + * \param pi - Returned device attribute value + * \param attrib - Device attribute to query + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaDeviceGetAttribute, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev); + +/** + * \brief Return NvSciSync attributes that this device can support. + * + * Returns in \p nvSciSyncAttrList, the properties of NvSciSync that + * this CUDA device, \p dev can support. The returned \p nvSciSyncAttrList + * can be used to create an NvSciSync object that matches this device's + * capabilities. + * + * If NvSciSyncAttrKey_RequiredPerm field in \p nvSciSyncAttrList is + * already set this API will return ::CUDA_ERROR_INVALID_VALUE. + * + * The applications should set \p nvSciSyncAttrList to a valid + * NvSciSyncAttrList failing which this API will return + * ::CUDA_ERROR_INVALID_HANDLE. + * + * The \p flags controls how applications intends to use + * the NvSciSync created from the \p nvSciSyncAttrList. The valid flags are: + * - ::CUDA_NVSCISYNC_ATTR_SIGNAL, specifies that the applications intends to + * signal an NvSciSync on this CUDA device. + * - ::CUDA_NVSCISYNC_ATTR_WAIT, specifies that the applications intends to + * wait on an NvSciSync on this CUDA device. + * + * At least one of these flags must be set, failing which the API + * returns ::CUDA_ERROR_INVALID_VALUE. Both the flags are orthogonal + * to one another: a developer may set both these flags that allows to + * set both wait and signal specific attributes in the same \p + * nvSciSyncAttrList. + * + * Note that this API updates the input \p nvSciSyncAttrList with values + * equivalent to the following public attribute key-values: + * NvSciSyncAttrKey_RequiredPerm is set to + * - NvSciSyncAccessPerm_SignalOnly if ::CUDA_NVSCISYNC_ATTR_SIGNAL is set in \p + * flags. + * - NvSciSyncAccessPerm_WaitOnly if ::CUDA_NVSCISYNC_ATTR_WAIT is set in \p + * flags. + * - NvSciSyncAccessPerm_WaitSignal if both ::CUDA_NVSCISYNC_ATTR_WAIT and + * ::CUDA_NVSCISYNC_ATTR_SIGNAL are set in \p flags. + * NvSciSyncAttrKey_PrimitiveInfo is set to + * - NvSciSyncAttrValPrimitiveType_SysmemSemaphore on any valid \p device. + * - NvSciSyncAttrValPrimitiveType_Syncpoint if \p device is a Tegra device. + * - NvSciSyncAttrValPrimitiveType_SysmemSemaphorePayload64b if \p device is + * GA10X+. NvSciSyncAttrKey_GpuId is set to the same UUID that is returned for + * this \p device from ::cuDeviceGetUuid. + * + * \param nvSciSyncAttrList - Return NvSciSync attributes supported. + * \param dev - Valid Cuda Device to get NvSciSync attributes + * for. \param flags - flags describing NvSciSync usage. 
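+ *
+ * For illustration, a minimal sketch (error handling omitted; attrList is
+ * assumed to be a valid, empty NvSciSyncAttrList created beforehand with the
+ * NvSciSync library, and dev a valid device handle):
+ *
+ * \code
+   // Request attributes suitable for both signaling and waiting on this device.
+   CUresult res = cuDeviceGetNvSciSyncAttributes(
+       attrList, dev, CUDA_NVSCISYNC_ATTR_SIGNAL | CUDA_NVSCISYNC_ATTR_WAIT);
+ * \endcode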
+ * + * \return + * + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa + * ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuDeviceGetNvSciSyncAttributes(void *nvSciSyncAttrList, + CUdevice dev, int flags); + +/** + * \brief Sets the current memory pool of a device + * + * The memory pool must be local to the specified device. + * ::cuMemAllocAsync allocates from the current mempool of the provided stream's + * device. By default, a device's current memory pool is its default memory + * pool. + * + * \note Use ::cuMemAllocFromPoolAsync to specify asynchronous allocations from + * a device different than the one the stream runs on. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolDestroy, ::cuMemAllocFromPoolAsync + */ +CUresult CUDAAPI cuDeviceSetMemPool(CUdevice dev, CUmemoryPool pool); + +/** + * \brief Gets the current mempool for a device + * + * Returns the last pool provided to ::cuDeviceSetMemPool for this device + * or the device's default memory pool if ::cuDeviceSetMemPool has never been + * called. By default the current mempool is the default mempool for a device. + * Otherwise the returned pool must have been set with ::cuDeviceSetMemPool. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate, ::cuDeviceSetMemPool + */ +CUresult CUDAAPI cuDeviceGetMemPool(CUmemoryPool *pool, CUdevice dev); + +/** + * \brief Returns the default mempool of a device + * + * The default mempool of a device contains device memory from that device. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemAllocAsync, ::cuMemPoolTrimTo, ::cuMemPoolGetAttribute, + * ::cuMemPoolSetAttribute, cuMemPoolSetAccess, ::cuDeviceGetMemPool, + * ::cuMemPoolCreate + */ +CUresult CUDAAPI cuDeviceGetDefaultMemPool(CUmemoryPool *pool_out, + CUdevice dev); + +/** + * \brief Returns information about the execution affinity support of the + * device. + * + * Returns in \p *pi whether execution affinity type \p type is supported by + * device \p dev. 
The supported types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: 1 if context with limited SMs is + * supported by the device, or 0 if not; + * + * \param pi - 1 if the execution affinity type \p type is supported by the + * device, or 0 if not \param type - Execution affinity type to query \param dev + * - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, + CUexecAffinityType type, + CUdevice dev); + +/** + * \brief Blocks until remote writes are visible to the specified scope + * + * Blocks until GPUDirect RDMA writes to the target context via mappings + * created through APIs like nvidia_p2p_get_pages (see + * https://docs.nvidia.com/cuda/gpudirect-rdma for more information), are + * visible to the specified scope. + * + * If the scope equals or lies within the scope indicated by + * ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING, the call + * will be a no-op and can be safely omitted for performance. This can be + * determined by comparing the numerical values between the two enums, with + * smaller scopes having smaller values. + * + * Users may query support for this API via + * ::CU_DEVICE_ATTRIBUTE_FLUSH_FLUSH_GPU_DIRECT_RDMA_OPTIONS. + * + * \param target - The target of the operation, see + * ::CUflushGPUDirectRDMAWritesTarget \param scope - The scope of the + * operation, see ::CUflushGPUDirectRDMAWritesScope + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + */ +CUresult CUDAAPI +cuFlushGPUDirectRDMAWrites(CUflushGPUDirectRDMAWritesTarget target, + CUflushGPUDirectRDMAWritesScope scope); + +/** @} */ /* END CUDA_DEVICE */ + +/** + * \defgroup CUDA_DEVICE_DEPRECATED Device Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated device management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the device management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns properties for a selected device + * + * \deprecated + * + * This function was deprecated as of CUDA 5.0 and replaced by + ::cuDeviceGetAttribute(). + * + * Returns in \p *prop the properties of device \p dev. 
The ::CUdevprop + * structure is defined as: + * + * \code + typedef struct CUdevprop_st { + int maxThreadsPerBlock; + int maxThreadsDim[3]; + int maxGridSize[3]; + int sharedMemPerBlock; + int totalConstantMemory; + int SIMDWidth; + int memPitch; + int regsPerBlock; + int clockRate; + int textureAlign + } CUdevprop; + * \endcode + * where: + * + * - ::maxThreadsPerBlock is the maximum number of threads per block; + * - ::maxThreadsDim[3] is the maximum sizes of each dimension of a block; + * - ::maxGridSize[3] is the maximum sizes of each dimension of a grid; + * - ::sharedMemPerBlock is the total amount of shared memory available per + * block in bytes; + * - ::totalConstantMemory is the total amount of constant memory available on + * the device in bytes; + * - ::SIMDWidth is the warp size; + * - ::memPitch is the maximum pitch allowed by the memory copy functions that + * involve memory regions allocated through ::cuMemAllocPitch(); + * - ::regsPerBlock is the total number of registers available per block; + * - ::clockRate is the clock frequency in kilohertz; + * - ::textureAlign is the alignment requirement; texture base addresses that + * are aligned to ::textureAlign bytes do not need an offset applied to + * texture fetches. + * + * \param prop - Returned properties of device + * \param dev - Device to get properties for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, + CUdevice dev); + +/** + * \brief Returns the compute capability of the device + * + * \deprecated + * + * This function was deprecated as of CUDA 5.0 and its functionality superseded + * by ::cuDeviceGetAttribute(). + * + * Returns in \p *major and \p *minor the major and minor revision numbers that + * define the compute capability of the device \p dev. + * + * \param major - Major revision number + * \param minor - Minor revision number + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, + int *minor, + CUdevice dev); + +/** @} */ /* END CUDA_DEVICE_DEPRECATED */ + +/** + * \defgroup CUDA_PRIMARY_CTX Primary Context Management + * + * ___MANBRIEF___ primary context management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the primary context management functions of the + * low-level CUDA driver application programming interface. + * + * The primary context is unique per device and shared with the CUDA runtime + * API. These functions allow integration with other libraries using CUDA. + * + * @{ + */ + +/** + * \brief Retain the primary context on the GPU + * + * Retains the primary context on the device. 
+ * Once the user successfully retains the primary context, the primary context + * will be active and available to the user until the user releases it + * with ::cuDevicePrimaryCtxRelease() or resets it with + * ::cuDevicePrimaryCtxReset(). Unlike ::cuCtxCreate() the newly retained + * context is not pushed onto the stack. + * + * Retaining the primary context for the first time will fail with + * ::CUDA_ERROR_UNKNOWN if the compute mode of the device is + * ::CU_COMPUTEMODE_PROHIBITED. The function + * ::cuDeviceGetAttribute() can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE + * to determine the compute mode of the device. The nvidia-smi tool can + * be used to set the compute mode for devices. Documentation for + * nvidia-smi can be obtained by passing a -h option to it. + * + * Please note that the primary context always supports pinned allocations. + * Other flags can be specified by ::cuDevicePrimaryCtxSetFlags(). + * + * \param pctx - Returned context handle of the new context + * \param dev - Device for which primary context is requested + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRelease, + * ::cuDevicePrimaryCtxSetFlags, + * ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev); + +/** + * \brief Release the primary context on the GPU + * + * Releases the primary context interop on the device. + * A retained context should always be released once the user is done using + * it. The context is automatically reset once the last reference to it is + * released. This behavior is different when the primary context was retained + * by the CUDA runtime from CUDA 4.0 and earlier. In this case, the primary + * context remains always active. + * + * Releasing a primary context that has not been previously retained will + * fail with ::CUDA_ERROR_INVALID_CONTEXT. + * + * Please note that unlike ::cuCtxDestroy() this method does not pop the context + * from stack in any circumstances. + * + * \param dev - Device which primary context is released + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); + +/** + * \brief Set flags for the primary context + * + * Sets the flags for the primary context on the device overwriting previously + * set ones. + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. 
+ * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment + * variables are described in the CUDA-GDB user guide under the "GPU core dump + * support" section. The initial settings will be taken from the global settings + * at the time of context creation. The other settings that control coredump + * output can be modified by calling ::cuCoredumpSetAttribute from the created + * context after it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial settings will be taken from the global settings at the time of + * context creation. The other settings that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * + * - ::CU_CTX_SYNC_MEMOPS: Ensures that synchronous memory operations initiated + * on this context will always synchronize. See further documentation in the + * section titled "API Synchronization behavior" to learn more about cases when + * synchronous memory operations can exhibit asynchronous behavior. + * + * \param dev - Device for which the primary context flags are set + * \param flags - New flags for the device + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuDevicePrimaryCtxGetState, + * ::cuCtxCreate, + * ::cuCtxGetFlags, + * ::cuCtxSetFlags, + * ::cudaSetDeviceFlags + */ +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); + +/** + * \brief Get the state of the primary context + * + * Returns in \p *flags the flags for the primary context of \p dev, and in + * \p *active whether it is active. See ::cuDevicePrimaryCtxSetFlags for flag + * values. 
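+ *
+ * For illustration, a minimal sketch (error handling omitted; dev is assumed
+ * to be a valid device handle) that configures, retains, inspects and finally
+ * releases the primary context of a device:
+ *
+ * \code
+   CUcontext primary;
+   unsigned int flags = 0;
+   int active = 0;
+   cuDevicePrimaryCtxSetFlags(dev, CU_CTX_SCHED_BLOCKING_SYNC);
+   cuDevicePrimaryCtxRetain(&primary, dev);            // context becomes active
+   cuDevicePrimaryCtxGetState(dev, &flags, &active);   // active is now 1
+   // ... use the context, e.g. after cuCtxSetCurrent(primary) ...
+   cuDevicePrimaryCtxRelease(dev);
+ * \endcode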
+ * + * \param dev - Device to get primary context flags for + * \param flags - Pointer to store flags + * \param active - Pointer to store context state; 0 = inactive, 1 = active + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa + * ::cuDevicePrimaryCtxSetFlags, + * ::cuCtxGetFlags, + * ::cuCtxSetFlags, + * ::cudaGetDeviceFlags + */ +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, + int *active); + +/** + * \brief Destroy all allocations and reset all state on the primary context + * + * Explicitly destroys and cleans up all resources associated with the current + * device in the current process. + * + * Note that it is responsibility of the calling function to ensure that no + * other module in the process is using the device any more. For that reason + * it is recommended to use ::cuDevicePrimaryCtxRelease() in most cases. + * However it is safe for other modules to call ::cuDevicePrimaryCtxRelease() + * even after resetting the device. + * Resetting the primary context does not release it, an application that has + * retained the primary context should explicitly release its usage. + * + * \param dev - Device for which primary context is destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuDevicePrimaryCtxRelease, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceReset + */ +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); + +/** @} */ /* END CUDA_PRIMARY_CTX */ + +/** + * \defgroup CUDA_CTX Context Management + * + * ___MANBRIEF___ context management functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the context management functions of the low-level + * CUDA driver application programming interface. + * + * Please note that some functions are described in + * \ref CUDA_PRIMARY_CTX "Primary Context Management" section. + * + * @{ + */ + +/** + * \brief Create a CUDA context + * + * \note In most cases it is recommended to use ::cuDevicePrimaryCtxRetain. + * + * Creates a new CUDA context and associates it with the calling thread. The + * \p flags parameter is described below. The context is created with a usage + * count of 1 and the caller of ::cuCtxCreate() must call ::cuCtxDestroy() + * when done using the context. If a context is already current to the thread, + * it is supplanted by the newly created context and may be restored by a + * subsequent call to ::cuCtxPopCurrent(). + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. 
+ * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. + * This flag must be set in order to allocate pinned host memory that is + * accessible to the GPU. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). + * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment + * variables are described in the CUDA-GDB user guide under the "GPU core dump + * support" section. The initial attributes will be taken from the global + * attributes at the time of context creation. The other attributes that control + * coredump output can be modified by calling ::cuCoredumpSetAttribute from the + * created context after it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial attributes will be taken from the global attributes at the time + * of context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * Setting this flag on any context creation is equivalent to setting the + * ::CU_COREDUMP_ENABLE_USER_TRIGGER attribute to \p true globally. + * + * - ::CU_CTX_SYNC_MEMOPS: Ensures that synchronous memory operations initiated + * on this context will always synchronize. See further documentation in the + * section titled "API Synchronization behavior" to learn more about cases when + * synchronous memory operations can exhibit asynchronous behavior. + * + * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of + * the device is ::CU_COMPUTEMODE_PROHIBITED. The function + * ::cuDeviceGetAttribute() can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE + * to determine the compute mode of the device. The nvidia-smi tool can + * be used to set the compute mode for * devices. Documentation for + * nvidia-smi can be obtained by passing a -h option to it. 
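+ *
+ * For illustration, a minimal sketch (error handling omitted; dev is assumed
+ * to be a valid device handle) that creates a context, uses it, and destroys
+ * it again:
+ *
+ * \code
+   CUcontext ctx;
+   cuCtxCreate(&ctx, CU_CTX_SCHED_BLOCKING_SYNC | CU_CTX_MAP_HOST, dev);
+   // ... the new context is current to the calling thread; issue work here ...
+   cuCtxDestroy(ctx);
+ * \endcode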
+ * + * \param pctx - Returned context handle of the new context + * \param flags - Context creation flags + * \param dev - Device to create context on + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCoredumpSetAttributeGlobal, + * ::cuCoredumpSetAttribute, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); + +/** + * \brief Create a CUDA context with execution affinity + * + * Creates a new CUDA context with execution affinity and associates it with + * the calling thread. The \p paramsArray and \p flags parameter are described + * below. The context is created with a usage count of 1 and the caller of + * ::cuCtxCreate() must call ::cuCtxDestroy() when done using the context. If a + * context is already current to the thread, it is supplanted by the newly + * created context and may be restored by a subsequent call to + * ::cuCtxPopCurrent(). + * + * The type and the amount of execution resource the context can use is limited + * by \p paramsArray and \p numParams. The \p paramsArray is an array of \p + * CUexecAffinityParam and the \p numParams describes the size of the array. If + * two \p CUexecAffinityParam in the array have the same type, the latter + * execution affinity parameter overrides the former execution affinity + * parameter. The supported execution affinity types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT limits the portion of SMs that the context + * can use. The portion of SMs is specified as the number of SMs via \p + * CUexecAffinitySmCount. This limit will be internally rounded up to the next + * hardware-supported amount. Hence, it is imperative to query the actual + * execution affinity of the context via \p cuCtxGetExecAffinity after context + * creation. Currently, this attribute is only supported under Volta+ MPS. + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. + * This flag must be set in order to allocate pinned host memory that is + * accessible to the GPU. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). + * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment + * variables are described in the CUDA-GDB user guide under the "GPU core dump + * support" section. The initial attributes will be taken from the global + * attributes at the time of context creation. The other attributes that control + * coredump output can be modified by calling ::cuCoredumpSetAttribute from the + * created context after it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial attributes will be taken from the global attributes at the time + * of context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * Setting this flag on any context creation is equivalent to setting the + * ::CU_COREDUMP_ENABLE_USER_TRIGGER attribute to \p true globally. + * + * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of + * the device is ::CU_COMPUTEMODE_PROHIBITED. The function + * ::cuDeviceGetAttribute() can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE + * to determine the compute mode of the device. The nvidia-smi tool can + * be used to set the compute mode for * devices. Documentation for + * nvidia-smi can be obtained by passing a -h option to it. 
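+ *
+ * For illustration, a minimal sketch (error handling omitted; dev is assumed
+ * to be a valid device handle and the SM count of 16 is an arbitrary example
+ * value) that creates a context limited to a number of SMs and then queries
+ * the affinity that was actually granted:
+ *
+ * \code
+   CUexecAffinityParam param;
+   CUcontext ctx;
+   param.type = CU_EXEC_AFFINITY_TYPE_SM_COUNT;
+   param.param.smCount.val = 16;                  // requested number of SMs
+   cuCtxCreate_v3(&ctx, &param, 1, 0, dev);
+   cuCtxGetExecAffinity(&param, CU_EXEC_AFFINITY_TYPE_SM_COUNT);
+   // param.param.smCount.val now holds the granted (possibly rounded-up) value
+ * \endcode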
+ * + * \param pctx - Returned context handle of the new context + * \param paramsArray - Execution affinity parameters + * \param numParams - Number of execution affinity parameters + * \param flags - Context creation flags + * \param dev - Device to create context on + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuCoredumpSetAttributeGlobal, + * ::cuCoredumpSetAttribute, + * ::CUexecAffinityParam + */ +CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, + CUexecAffinityParam *paramsArray, int numParams, + unsigned int flags, CUdevice dev); + +/** + * \brief Destroy a CUDA context + * + * Destroys the CUDA context specified by \p ctx. The context \p ctx will be + * destroyed regardless of how many threads it is current to. + * It is the responsibility of the calling function to ensure that no API + * call issues using \p ctx while ::cuCtxDestroy() is executing. + * + * Destroys and cleans up all resources associated with the context. + * It is the caller's responsibility to ensure that the context or its resources + * are not accessed or passed in subsequent API calls and doing so will result + * in undefined behavior. These resources include CUDA types such as ::CUmodule, + * ::CUfunction, ::CUstream, ::CUevent, + * ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, + * ::CUsurfref, + * ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and + * ::CUexternalSemaphore. + * + * If \p ctx is current to the calling thread then \p ctx will also be + * popped from the current thread's context stack (as though ::cuCtxPopCurrent() + * were called). If \p ctx is current to other threads, then \p ctx will + * remain current to those threads, and attempting to access \p ctx from + * those threads will result in the error ::CUDA_ERROR_CONTEXT_IS_DESTROYED. + * + * \param ctx - Context to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); + +/** + * \brief Pushes a context on the current CPU thread + * + * Pushes the given context \p ctx onto the CPU thread's stack of current + * contexts. The specified context becomes the CPU thread's current context, so + * all CUDA functions that operate on the current context are affected. + * + * The previous current context may be made current again by calling + * ::cuCtxDestroy() or ::cuCtxPopCurrent(). 
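+ *
+ * For illustration, a minimal sketch (error handling omitted; ctx is assumed
+ * to be a valid context) that temporarily makes ctx current and then restores
+ * the previously current context:
+ *
+ * \code
+   CUcontext popped;
+   cuCtxPushCurrent(ctx);       // ctx is now current to this thread
+   // ... issue work against ctx ...
+   cuCtxPopCurrent(&popped);    // popped == ctx; the prior context is current again
+ * \endcode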
+ * + * \param ctx - Context to push + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); + +/** + * \brief Pops the current CUDA context from the current CPU thread. + * + * Pops the current CUDA context from the CPU thread and passes back the + * old context handle in \p *pctx. That context may then be made current + * to a different CPU thread by calling ::cuCtxPushCurrent(). + * + * If a context was current to the CPU thread before ::cuCtxCreate() or + * ::cuCtxPushCurrent() was called, this function makes that context current to + * the CPU thread again. + * + * \param pctx - Returned popped context handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); + +/** + * \brief Binds the specified CUDA context to the calling CPU thread + * + * Binds the specified CUDA context to the calling CPU thread. + * If \p ctx is NULL then the CUDA context previously bound to the + * calling CPU thread is unbound and ::CUDA_SUCCESS is returned. + * + * If there exists a CUDA context stack on the calling CPU thread, this + * will replace the top of that stack with \p ctx. + * If \p ctx is NULL then this will be equivalent to popping the top + * of the calling CPU thread's CUDA context stack (or a no-op if the + * calling CPU thread's CUDA context stack is empty). + * + * \param ctx - Context to bind to the calling CPU thread + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa + * ::cuCtxGetCurrent, + * ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cudaSetDevice + */ +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx); + +/** + * \brief Returns the CUDA context bound to the calling CPU thread. + * + * Returns in \p *pctx the CUDA context bound to the calling CPU thread. + * If no context is bound to the calling CPU thread then \p *pctx is + * set to NULL and ::CUDA_SUCCESS is returned. + * + * \param pctx - Returned context handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * \notefnerr + * + * \sa + * ::cuCtxSetCurrent, + * ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cudaGetDevice + */ +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx); + +/** + * \brief Returns the device ID for the current context + * + * Returns in \p *device the ordinal of the current context's device. 
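+ *
+ * For illustration, a minimal sketch (error handling omitted; ctx is assumed
+ * to be a valid context) that binds a context to the calling thread and then
+ * queries the current context and its device:
+ *
+ * \code
+   CUcontext current = NULL;
+   CUdevice dev;
+   cuCtxSetCurrent(ctx);        // bind ctx to this CPU thread
+   cuCtxGetCurrent(&current);   // current == ctx
+   cuCtxGetDevice(&dev);        // device backing the current context
+ * \endcode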
+ * + * \param device - Returned device ID for the current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaGetDevice + */ +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device); + +/** + * \brief Returns the flags for the current context + * + * Returns in \p *flags the flags of the current context. See ::cuCtxCreate + * for flag values. + * + * \param flags - Pointer to store flags of current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetCurrent, + * ::cuCtxGetDevice, + * ::cuCtxGetLimit, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxGetStreamPriorityRange, + * ::cuCtxSetFlags, + * ::cudaGetDeviceFlags + */ +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags); + +/** + * \brief Sets the flags for the current context + * + * Sets the flags for the current context overwriting previously set ones. See + * ::cuDevicePrimaryCtxSetFlags for flag values. + * + * \param flags - Flags to set on the current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetCurrent, + * ::cuCtxGetDevice, + * ::cuCtxGetLimit, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxGetStreamPriorityRange, + * ::cuCtxGetFlags, + * ::cudaGetDeviceFlags, + * ::cuDevicePrimaryCtxSetFlags, + */ +CUresult CUDAAPI cuCtxSetFlags(unsigned int flags); + +/** + * \brief Returns the unique Id associated with the context supplied + * + * Returns in \p ctxId the unique Id which is associated with a given context. + * The Id is unique for the life of the program for this instance of CUDA. + * If context is supplied as NULL and there is one current, the Id of the + * current context is returned. + * + * \param ctx - Context for which to obtain the Id + * \param ctxId - Pointer to store the Id of the context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPushCurrent + */ +CUresult CUDAAPI cuCtxGetId(CUcontext ctx, unsigned long long *ctxId); + +/** + * \brief Block for a context's tasks to complete + * + * Blocks until the device has completed all preceding requested tasks. + * ::cuCtxSynchronize() returns an error if one of the preceding tasks failed. + * If the context was created with the ::CU_CTX_SCHED_BLOCKING_SYNC flag, the + * CPU thread will block until the GPU context has finished its work. 
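+ *
+ * For illustration, a minimal sketch that waits for all work previously issued
+ * in the current context and checks whether any preceding task failed:
+ *
+ * \code
+   CUresult res = cuCtxSynchronize();
+   if (res != CUDA_SUCCESS) {
+       // a preceding asynchronous task failed, or no valid context is current
+   }
+ * \endcode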
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cudaDeviceSynchronize + */ +CUresult CUDAAPI cuCtxSynchronize(void); + +/** + * \brief Set resource limits + * + * Setting \p limit to \p value is a request by the application to update + * the current limit maintained by the context. The driver is free to + * modify the requested value to meet h/w requirements (this could be + * clamping to minimum or maximum values, rounding up to nearest element + * size, etc). The application can use ::cuCtxGetLimit() to find out exactly + * what the limit has been set to. + * + * Setting each ::CUlimit has its own specific restrictions, so each is + * discussed here. + * + * - ::CU_LIMIT_STACK_SIZE controls the stack size in bytes of each GPU thread. + * The driver automatically increases the per-thread stack size + * for each kernel launch as needed. This size isn't reset back to the + * original value after each launch. Setting this value will take effect + * immediately, and if necessary, the device will block until all preceding + * requested tasks are complete. + * + * - ::CU_LIMIT_PRINTF_FIFO_SIZE controls the size in bytes of the FIFO used + * by the ::printf() device system call. Setting ::CU_LIMIT_PRINTF_FIFO_SIZE + * must be performed before launching any kernel that uses the ::printf() + * device system call, otherwise ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * - ::CU_LIMIT_MALLOC_HEAP_SIZE controls the size in bytes of the heap used + * by the ::malloc() and ::free() device system calls. Setting + * ::CU_LIMIT_MALLOC_HEAP_SIZE must be performed before launching any kernel + * that uses the ::malloc() or ::free() device system calls, otherwise + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * - ::CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH controls the maximum nesting depth of + * a grid at which a thread can safely call ::cudaDeviceSynchronize(). Setting + * this limit must be performed before any launch of a kernel that uses the + * device runtime and calls ::cudaDeviceSynchronize() above the default sync + * depth, two levels of grids. Calls to ::cudaDeviceSynchronize() will fail + * with error code ::cudaErrorSyncDepthExceeded if the limitation is + * violated. This limit can be set smaller than the default or up the maximum + * launch depth of 24. When setting this limit, keep in mind that additional + * levels of sync depth require the driver to reserve large amounts of device + * memory which can no longer be used for user allocations. If these + * reservations of device memory fail, ::cuCtxSetLimit() will return + * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value. + * This limit is only applicable to devices of compute capability < 9.0. + * Attempting to set this limit on devices of other compute capability + * versions will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being + * returned. + * + * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT controls the maximum number of + * outstanding device runtime launches that can be made from the current + * context. A grid is outstanding from the point of launch up until the grid + * is known to have been completed. 
Device runtime launches which violate + * this limitation fail and return ::cudaErrorLaunchPendingCountExceeded when + * ::cudaGetLastError() is called after launch. If more pending launches than + * the default (2048 launches) are needed for a module using the device + * runtime, this limit can be increased. Keep in mind that being able to + * sustain additional pending launches will require the driver to reserve + * larger amounts of device memory upfront which can no longer be used for + * allocations. If these reservations fail, ::cuCtxSetLimit() will return + * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value. + * This limit is only applicable to devices of compute capability 3.5 and + * higher. Attempting to set this limit on devices of compute capability less + * than 3.5 will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being + * returned. + * + * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch + * granularity. Values can range from 0B to 128B. This is purely a performance + * hint and it can be ignored or clamped depending on the platform. + * + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for + * persisting L2 cache. This is purely a performance hint and it can be + * ignored or clamped depending on the platform. + * + * \param limit - Limit to set + * \param value - Size of limit + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_LIMIT, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSynchronize, + * ::cudaDeviceSetLimit + */ +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value); + +/** + * \brief Returns resource limits + * + * Returns in \p *pvalue the current size of \p limit. The supported + * ::CUlimit values are: + * - ::CU_LIMIT_STACK_SIZE: stack size in bytes of each GPU thread. + * - ::CU_LIMIT_PRINTF_FIFO_SIZE: size in bytes of the FIFO used by the + * ::printf() device system call. + * - ::CU_LIMIT_MALLOC_HEAP_SIZE: size in bytes of the heap used by the + * ::malloc() and ::free() device system calls. + * - ::CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH: maximum grid depth at which a thread + * can issue the device runtime call ::cudaDeviceSynchronize() to wait on + * child grid launches to complete. + * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT: maximum number of outstanding + * device runtime launches that can be made from this context. + * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY: L2 cache fetch granularity. + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE: Persisting L2 cache size in bytes + * + * \param limit - Limit to query + * \param pvalue - Returned size of limit + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_LIMIT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceGetLimit + */ +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit); + +/** + * \brief Returns the preferred cache configuration for the current context. 
+ * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this function returns through \p pconfig the preferred cache + * configuration for the current context. This is only a preference. The driver + * will use the requested configuration if possible, but it is free to choose a + * different configuration if required to execute functions. + * + * This will return a \p pconfig of ::CU_FUNC_CACHE_PREFER_NONE on devices + * where the size of the L1 cache and shared memory are fixed. + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 + * (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 + * cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param pconfig - Returned cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceGetCacheConfig + */ +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig); + +/** + * \brief Sets the preferred cache configuration for the current context. + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the current context. This is only a preference. The driver will use + * the requested configuration if possible, but it is free to choose a different + * configuration if required to execute the function. Any function preference + * set via ::cuFuncSetCacheConfig() or ::cuKernelSetCacheConfig() will be + * preferred over this context-wide setting. Setting the context-wide cache + * configuration to + * ::CU_FUNC_CACHE_PREFER_NONE will cause subsequent kernel launches to prefer + * to not change the cache configuration unless required to launch the kernel. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. + * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. 
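+ *
+ * For example (an illustrative sketch; this is only a hint, which the driver
+ * may ignore, and the configuration values are listed below):
+ * \code
+ * cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_SHARED);
+ * // ... subsequent launches in this context prefer a larger shared memory carveout
+ * \endcode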
+ * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 + * (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 + * cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param config - Requested cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceSetCacheConfig, + * ::cuKernelSetCacheConfig + */ +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config); + +/** + * \brief Gets the context's API version. + * + * Returns a version number in \p version corresponding to the capabilities of + * the context (e.g. 3010 or 3020), which library developers can use to direct + * callers to a specific API version. If \p ctx is NULL, returns the API version + * used to create the currently bound context. + * + * Note that new API versions are only introduced when context capabilities are + * changed that break binary compatibility, so the API version and driver + * version may be different. For example, it is valid for the API version to be + * 3020 while the driver version is 4020. + * + * \param ctx - Context to check + * \param version - Pointer to version + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version); + +/** + * \brief Returns numerical values that correspond to the least and + * greatest stream priorities. + * + * Returns in \p *leastPriority and \p *greatestPriority the numerical values + * that correspond to the least and greatest stream priorities respectively. + * Stream priorities follow a convention where lower numbers imply greater + * priorities. The range of meaningful stream priorities is given by [\p + * *greatestPriority, \p *leastPriority]. If the user attempts to create a + * stream with a priority value that is outside the meaningful range as + * specified by this API, the priority is automatically clamped down or up to + * either \p *leastPriority or \p *greatestPriority respectively. See + * ::cuStreamCreateWithPriority for details on creating a priority stream. A + * NULL may be passed in for \p *leastPriority or \p *greatestPriority if the + * value is not desired. + * + * This function will return '0' in both \p *leastPriority and \p + * *greatestPriority if the current context's device does not support stream + * priorities (see ::cuDeviceGetAttribute). 
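+ *
+ * For example (an illustrative sketch that creates a stream at the greatest
+ * meaningful priority; lower numbers mean higher priority, error handling
+ * omitted):
+ * \code
+ * int least = 0, greatest = 0;
+ * CUstream hi_prio = NULL;
+ * cuCtxGetStreamPriorityRange(&least, &greatest);
+ * cuStreamCreateWithPriority(&hi_prio, CU_STREAM_NON_BLOCKING, greatest);
+ * \endcode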
+ * + * \param leastPriority - Pointer to an int in which the numerical value for + * least stream priority is returned \param greatestPriority - Pointer to an int + * in which the numerical value for greatest stream priority is returned + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceGetStreamPriorityRange + */ +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, + int *greatestPriority); + +/** + * \brief Resets all persisting lines in cache to normal status. + * + * ::cuCtxResetPersistingL2Cache Resets all persisting lines in cache to normal + * status. Takes effect on function return. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuCtxResetPersistingL2Cache(void); + +/** + * \brief Returns the execution affinity setting for the current context. + * + * Returns in \p *pExecAffinity the current value of \p type. The supported + * ::CUexecAffinityType values are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: number of SMs the context is limited to + * use. + * + * \param type - Execution affinity type to query + * \param pExecAffinity - Returned execution affinity + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY + * \notefnerr + * + * \sa + * ::CUexecAffinityParam + */ +CUresult CUDAAPI cuCtxGetExecAffinity(CUexecAffinityParam *pExecAffinity, + CUexecAffinityType type); + +/** @} */ /* END CUDA_CTX */ + +/** + * \defgroup CUDA_CTX_DEPRECATED Context Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated context management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated context management functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Increment a context's usage-count + * + * \deprecated + * + * Note that this function is deprecated and should not be used. + * + * Increments the usage count of the context and passes back a context handle + * in \p *pctx that must be passed to ::cuCtxDetach() when the application is + * done with the context. ::cuCtxAttach() fails if there is no context current + * to the thread. + * + * Currently, the \p flags parameter must be 0. + * + * \param pctx - Returned context handle of the current context + * \param flags - Context attach flags (must be 0) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxDetach, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, + unsigned int flags); + +/** + * \brief Decrement a context's usage-count + * + * \deprecated + * + * Note that this function is deprecated and should not be used. 
+ * + * Decrements the usage count of the context \p ctx, and destroys the context + * if the usage count goes to 0. The context must be a handle that was passed + * back by ::cuCtxCreate() or ::cuCtxAttach(), and must be current to the + * calling thread. + * + * \param ctx - Context to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx); + +/** + * \brief Returns the current shared memory configuration for the current + * context. + * + * \deprecated + * + * This function will return in \p pConfig the current size of shared memory + * banks in the current context. On devices with configurable shared memory + * banks, + * ::cuCtxSetSharedMemConfig can be used to change this setting, so that all + * subsequent kernel launches will by default use the new bank size. When + * ::cuCtxGetSharedMemConfig is called on devices without configurable shared + * memory, it will return the fixed bank size of the hardware. + * + * The returned bank configurations can be either: + * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: shared memory bank width is + * four bytes. + * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: shared memory bank width will + * eight bytes. + * + * \param pConfig - returned shared memory configuration + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuCtxGetSharedMemConfig, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceGetSharedMemConfig + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuCtxGetSharedMemConfig(CUsharedconfig *pConfig); + +/** + * \brief Sets the shared memory configuration for the current context. + * + * \deprecated + * + * On devices with configurable shared memory banks, this function will set + * the context's shared memory bank size which is used for subsequent kernel + * launches. + * + * Changed the shared memory configuration between launches may insert a device + * side synchronization point between those launches. + * + * Changing the shared memory bank size will not increase shared memory usage + * or affect occupancy of kernels, but may have major effects on performance. + * Larger bank sizes will allow for greater potential bandwidth to shared + * memory, but will change what kinds of accesses to shared memory will result + * in bank conflicts. + * + * This function will do nothing on devices with fixed shared memory bank size. + * + * The supported bank configurations are: + * - ::CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE: set bank width to the default + * initial setting (currently, four bytes). + * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: set shared memory bank width to + * be natively four bytes. + * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: set shared memory bank width + * to be natively eight bytes. 
+ * + * \param config - requested shared memory configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuCtxGetSharedMemConfig, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceSetSharedMemConfig + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuCtxSetSharedMemConfig(CUsharedconfig config); + +/** @} */ /* END CUDA_CTX_DEPRECATED */ + +/** + * \defgroup CUDA_MODULE Module Management + * + * ___MANBRIEF___ module management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the module management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Loads a compute module + * + * Takes a filename \p fname and loads the corresponding module \p module into + * the current context. The CUDA driver API does not attempt to lazily + * allocate the resources needed by a module; if the memory for functions and + * data (constant and global) needed by the module cannot be allocated, + * ::cuModuleLoad() fails. The file should be a \e cubin file as output by + * \b nvcc, or a \e PTX file either as output by \b nvcc or handwritten, or + * a \e fatbin file as output by \b nvcc from toolchain 4.0 or later. + * + * \param module - Returned module + * \param fname - Filename of module to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_FILE_NOT_FOUND, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname); + +/** + * \brief Load a module's data + * + * Takes a pointer \p image and loads the corresponding module \p module into + * the current context. The \p image may be a \e cubin or \e fatbin + * as output by \b nvcc, or a NULL-terminated \e PTX, either as output by \b + * nvcc or hand-written. 
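+ *
+ * For example (an illustrative sketch; the PTX string and kernel name are
+ * placeholders, error handling omitted):
+ * \code
+ * const char *ptx = "...";  // NULL-terminated PTX, from nvcc or hand-written
+ * CUmodule   mod  = NULL;
+ * CUfunction fn   = NULL;
+ * cuModuleLoadData(&mod, ptx);
+ * cuModuleGetFunction(&fn, mod, "my_kernel");
+ * \endcode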
+ * + * \param module - Returned module + * \param image - Module data to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image); + +/** + * \brief Load a module's data with options + * + * Takes a pointer \p image and loads the corresponding module \p module into + * the current context. The \p image may be a \e cubin or \e fatbin + * as output by \b nvcc, or a NULL-terminated \e PTX, either as output by \b + * nvcc or hand-written. + * + * \param module - Returned module + * \param image - Module data to load + * \param numOptions - Number of options + * \param options - Options for JIT + * \param optionValues - Option values for JIT + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, + unsigned int numOptions, + CUjit_option *options, void **optionValues); + +/** + * \brief Load a module's data + * + * Takes a pointer \p fatCubin and loads the corresponding module \p module + * into the current context. The pointer represents a fat binary object, + * which is a collection of different \e cubin and/or \e PTX files, all + * representing the same device code, but compiled and optimized for different + * architectures. + * + * Prior to CUDA 4.0, there was no documented API for constructing and using + * fat binary objects by programmers. Starting with CUDA 4.0, fat binary + * objects can be constructed by providing the -fatbin option to \b nvcc. + * More information can be found in the \b nvcc document. 
+ * + * \param module - Returned module + * \param fatCubin - Fat binary to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); + +/** + * \brief Unloads a module + * + * Unloads a module \p hmod from the current context. Attempting to unload + * a module which was obtained from the Library Management API such as + * ::cuLibraryGetModule will return ::CUDA_ERROR_NOT_PERMITTED. + * + * \param hmod - Module to unload + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED + * \notefnerr + * \note_destroy_ub + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary + */ +CUresult CUDAAPI cuModuleUnload(CUmodule hmod); + +/** + * CUDA Lazy Loading status + */ +typedef enum CUmoduleLoadingMode_enum { + CU_MODULE_EAGER_LOADING = 0x1, /**< Lazy Kernel Loading is not enabled */ + CU_MODULE_LAZY_LOADING = 0x2, /**< Lazy Kernel Loading is enabled */ +} CUmoduleLoadingMode; + +/** + * \brief Query lazy loading mode + * + * Returns lazy loading mode + * Module loading mode is controlled by CUDA_MODULE_LOADING env variable + * + * \param mode - Returns the lazy loading mode + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa + * ::cuModuleLoad, + */ +CUresult CUDAAPI cuModuleGetLoadingMode(CUmoduleLoadingMode *mode); + +/** + * \brief Returns a function handle + * + * Returns in \p *hfunc the handle of the function of name \p name located in + * module \p hmod. If no function of that name exists, ::cuModuleGetFunction() + * returns ::CUDA_ERROR_NOT_FOUND. + * + * \param hfunc - Returned function handle + * \param hmod - Module to retrieve function from + * \param name - Name of function to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name); + +/** + * \brief Returns the number of functions within a module + * + * Returns in \p count the number of functions in \p mod. 
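+ *
+ * For example (an illustrative sketch; assumes \p mod was loaded with one of
+ * the ::cuModuleLoad* functions, allocation and error handling abbreviated):
+ * \code
+ * unsigned int count = 0;
+ * cuModuleGetFunctionCount(&count, mod);
+ * CUfunction *fns = (CUfunction *)malloc(count * sizeof(CUfunction));
+ * cuModuleEnumerateFunctions(fns, count, mod);
+ * // ... use fns[0..count-1]; handles stay valid until the module is unloaded
+ * free(fns);
+ * \endcode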
+ * + * \param count - Number of functions found within the module + * \param mod - Module to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + */ +CUresult CUDAAPI cuModuleGetFunctionCount(unsigned int *count, CUmodule mod); + +/** + * \brief Returns the function handles within a module. + * + * Returns in \p functions a maximum number of \p numFunctions function handles + * within \p mod. When function loading mode is set to LAZY the function + * retrieved may be partially loaded. The loading state of a function can be + * queried using ::cuFunctionIsLoaded. CUDA APIs may load the function + * automatically when called with partially loaded function handle which may + * incur additional latency. Alternatively, ::cuFunctionLoad can be used to + * explicitly load a function. The returned function handles become invalid when + * the module is unloaded. + * + * \param functions - Buffer where the function handles are returned to + * \param numFunctions - Maximum number of function handles may be returned to + * the buffer \param mod - Module to query from + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetFunctionCount, + * ::cuFuncIsLoaded, + * ::cuFuncLoad + */ +CUresult CUDAAPI cuModuleEnumerateFunctions(CUfunction *functions, + unsigned int numFunctions, + CUmodule mod); + +/** + * \brief Returns a global pointer from a module + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the + * global of name \p name located in module \p hmod. If no variable of that name + * exists, ::cuModuleGetGlobal() returns ::CUDA_ERROR_NOT_FOUND. + * One of the parameters \p dptr or \p bytes (not both) can be NULL in which + * case it is ignored. + * + * \param dptr - Returned global device pointer + * \param bytes - Returned global size in bytes + * \param hmod - Module to retrieve global from + * \param name - Name of global to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload, + * ::cudaGetSymbolAddress, + * ::cudaGetSymbolSize + */ +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name); + +/** + * \brief Creates a pending JIT linker invocation. + * + * If the call is successful, the caller owns the returned CUlinkState, which + * should eventually be destroyed with ::cuLinkDestroy. The + * device code machine size (32 or 64 bit) will match the calling application. + * + * Both linker and compiler options may be specified. Compiler options will + * be applied to inputs to this linker action which must be compiled from PTX. + * The options ::CU_JIT_WALL_TIME, + * ::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, and ::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES + * will accumulate data until the CUlinkState is destroyed. + * + * \p optionValues must remain valid for the life of the CUlinkState if output + * options are used. No other references to inputs are maintained after this + * call returns. 
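+ *
+ * For example (an illustrative sketch of a complete JIT link with default
+ * options; the PTX string and input name are placeholders, error handling
+ * omitted):
+ * \code
+ * const char *ptx   = "...";  // NULL-terminated PTX input
+ * CUlinkState state = NULL;
+ * void       *cubin = NULL;
+ * size_t      size  = 0;
+ * CUmodule    mod   = NULL;
+ * cuLinkCreate(0, NULL, NULL, &state);
+ * cuLinkAddData(state, CU_JIT_INPUT_PTX, (void *)ptx, strlen(ptx) + 1,
+ *               "my_ptx", 0, NULL, NULL);
+ * cuLinkComplete(state, &cubin, &size);  // cubin is owned by 'state'
+ * cuModuleLoadData(&mod, cubin);         // load before destroying 'state'
+ * cuLinkDestroy(state);
+ * \endcode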
+ * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 + * will be accepted + * + * \param numOptions Size of options arrays + * \param options Array of linker and compiler options + * \param optionValues Array of option values, each cast to void * + * \param stateOut On success, this will contain a CUlinkState to specify + * and complete this action + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuLinkAddData, + * ::cuLinkAddFile, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut); + +/** + * \brief Add an input to a pending linker invocation + * + * Ownership of \p data is retained by the caller. No reference is retained to + * any inputs after this call returns. + * + * This method accepts only compiler options, which are used if the data must + * be compiled from PTX, and does not accept any of + * ::CU_JIT_WALL_TIME, ::CU_JIT_INFO_LOG_BUFFER, ::CU_JIT_ERROR_LOG_BUFFER, + * ::CU_JIT_TARGET_FROM_CUCONTEXT, or ::CU_JIT_TARGET. + * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 + * will be accepted + * + * \param state A pending linker action. + * \param type The type of the input data. + * \param data The input data. PTX must be NULL-terminated. + * \param size The length of the input data. + * \param name An optional name for this input in log messages. + * \param numOptions Size of options. + * \param options Options to be applied only for this input (overrides + * options from ::cuLinkCreate). \param optionValues Array of option values, + * each cast to void *. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU + * + * \sa ::cuLinkCreate, + * ::cuLinkAddFile, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void **optionValues); + +/** + * \brief Add a file input to a pending linker invocation + * + * No reference is retained to any inputs after this call returns. + * + * This method accepts only compiler options, which are used if the input + * must be compiled from PTX, and does not accept any of + * ::CU_JIT_WALL_TIME, ::CU_JIT_INFO_LOG_BUFFER, ::CU_JIT_ERROR_LOG_BUFFER, + * ::CU_JIT_TARGET_FROM_CUCONTEXT, or ::CU_JIT_TARGET. + * + * This method is equivalent to invoking ::cuLinkAddData on the contents + * of the file. 
+ * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 + * will be accepted + * + * \param state A pending linker action + * \param type The type of the input data + * \param path Path to the input file + * \param numOptions Size of options + * \param options Options to be applied only for this input (overrides + * options from ::cuLinkCreate) \param optionValues Array of option values, each + * cast to void * + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_FILE_NOT_FOUND + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU + * + * \sa ::cuLinkCreate, + * ::cuLinkAddData, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues); + +/** + * \brief Complete a pending linker invocation + * + * Completes the pending linker action and returns the cubin image for the + * linked device code, which can be used with ::cuModuleLoadData. The cubin is + * owned by \p state, so it should be loaded before \p state is destroyed via + * ::cuLinkDestroy. This call does not destroy \p state. + * + * \param state A pending linker invocation + * \param cubinOut On success, this will point to the output image + * \param sizeOut Optional parameter to receive the size of the generated image + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLinkCreate, + * ::cuLinkAddData, + * ::cuLinkAddFile, + * ::cuLinkDestroy, + * ::cuModuleLoadData + */ +CUresult CUDAAPI cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut); + +/** + * \brief Destroys state for a JIT linker invocation. + * + * \param state State object for the linker invocation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \sa ::cuLinkCreate + */ +CUresult CUDAAPI cuLinkDestroy(CUlinkState state); + +/** @} */ /* END CUDA_MODULE */ + +/** + * \defgroup CUDA_MODULE_DEPRECATED Module Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated module management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated module management functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns a handle to a texture reference + * + * \deprecated + * + * Returns in \p *pTexRef the handle of the texture reference of name \p name + * in the module \p hmod. If no texture reference of that name exists, + * ::cuModuleGetTexRef() returns ::CUDA_ERROR_NOT_FOUND. This texture reference + * handle should not be destroyed, since it will be destroyed when the module + * is unloaded. 
+ * + * \param pTexRef - Returned texture reference + * \param hmod - Module to retrieve texture reference from + * \param name - Name of texture reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa + * ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetSurfRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, + CUmodule hmod, + const char *name); + +/** + * \brief Returns a handle to a surface reference + * + * \deprecated + * + * Returns in \p *pSurfRef the handle of the surface reference of name \p name + * in the module \p hmod. If no surface reference of that name exists, + * ::cuModuleGetSurfRef() returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pSurfRef - Returned surface reference + * \param hmod - Module to retrieve surface reference from + * \param name - Name of surface reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa + * ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, + CUmodule hmod, + const char *name); + +/** @} */ /* END CUDA_MODULE_DEPRECATED */ + +/** + * \defgroup CUDA_LIBRARY Library Management + * + * ___MANBRIEF___ library management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the library management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Load a library with specified code and options + * + * Takes a pointer \p code and loads the corresponding library \p library based + * on the application defined library loading mode: + * - If module loading is set to EAGER, via the environment variables described + * in "Module loading", \p library is loaded eagerly into all contexts at the + * time of the call and future contexts at the time of creation until the + * library is unloaded with ::cuLibraryUnload(). + * - If the environment variables are set to LAZY, \p library + * is not immediately loaded onto all existent contexts and will only be + * loaded when a function is needed for that context, such as a kernel launch. + * + * These environment variables are described in the CUDA programming guide under + * the "CUDA environment variables" section. + * + * The \p code may be a \e cubin or \e fatbin as output by \b nvcc, + * or a NULL-terminated \e PTX, either as output by \b nvcc or hand-written. + * + * Options are passed as an array via \p jitOptions and any corresponding + * parameters are passed in \p jitOptionsValues. The number of total JIT options + * is supplied via \p numJitOptions. Any outputs will be returned via \p + * jitOptionsValues. + * + * Library load options are passed as an array via \p libraryOptions and any + * corresponding parameters are passed in \p libraryOptionValues. 
The number of + * total library load options is supplied via \p numLibraryOptions. + * + * \param library - Returned library + * \param code - Code to load + * \param jitOptions - Options for JIT + * \param jitOptionsValues - Option values for JIT + * \param numJitOptions - Number of options + * \param libraryOptions - Options for loading + * \param libraryOptionValues - Option values for loading + * \param numLibraryOptions - Number of options for loading + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * + * \sa ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx + */ +CUresult CUDAAPI cuLibraryLoadData(CUlibrary *library, const void *code, + CUjit_option *jitOptions, + void **jitOptionsValues, + unsigned int numJitOptions, + CUlibraryOption *libraryOptions, + void **libraryOptionValues, + unsigned int numLibraryOptions); + +/** + * \brief Load a library with specified file and options + * + * Takes a pointer \p code and loads the corresponding library \p library based + * on the application defined library loading mode: + * - If module loading is set to EAGER, via the environment variables described + * in "Module loading", \p library is loaded eagerly into all contexts at the + * time of the call and future contexts at the time of creation until the + * library is unloaded with ::cuLibraryUnload(). + * - If the environment variables are set to LAZY, \p library + * is not immediately loaded onto all existent contexts and will only be + * loaded when a function is needed for that context, such as a kernel launch. + * + * These environment variables are described in the CUDA programming guide under + * the "CUDA environment variables" section. + * + * The file should be a \e cubin file as output by \b nvcc, or a \e PTX file + * either as output by \b nvcc or handwritten, or a \e fatbin file as output by + * \b nvcc. + * + * Options are passed as an array via \p jitOptions and any corresponding + * parameters are passed in \p jitOptionsValues. The number of total options is + * supplied via \p numJitOptions. Any outputs will be returned via \p + * jitOptionsValues. + * + * Library load options are passed as an array via \p libraryOptions and any + * corresponding parameters are passed in \p libraryOptionValues. The number of + * total library load options is supplied via \p numLibraryOptions. 
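+ *
+ * For example (an illustrative sketch with no JIT or library options; the
+ * file and kernel names are placeholders, error handling omitted):
+ * \code
+ * CUlibrary lib    = NULL;
+ * CUkernel  kernel = NULL;
+ * cuLibraryLoadFromFile(&lib, "my_kernels.fatbin",
+ *                       NULL, NULL, 0,    // no JIT options
+ *                       NULL, NULL, 0);   // no library options
+ * cuLibraryGetKernel(&kernel, lib, "my_kernel");
+ * \endcode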
+ * + * \param library - Returned library + * \param fileName - File to load from + * \param jitOptions - Options for JIT + * \param jitOptionsValues - Option values for JIT + * \param numJitOptions - Number of options + * \param libraryOptions - Options for loading + * \param libraryOptionValues - Option values for loading + * \param numLibraryOptions - Number of options for loading + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryUnload, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx + */ +CUresult CUDAAPI cuLibraryLoadFromFile(CUlibrary *library, const char *fileName, + CUjit_option *jitOptions, + void **jitOptionsValues, + unsigned int numJitOptions, + CUlibraryOption *libraryOptions, + void **libraryOptionValues, + unsigned int numLibraryOptions); + +/** + * \brief Unloads a library + * + * Unloads the library specified with \p library + * + * \param library - Library to unload + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuLibraryUnload(CUlibrary library); + +/** + * \brief Returns a kernel handle + * + * Returns in \p pKernel the handle of the kernel with name \p name located in + * library \p library. If kernel handle is not found, the call returns + * ::CUDA_ERROR_NOT_FOUND. + * + * \param pKernel - Returned kernel handle + * \param library - Library to retrieve kernel from + * \param name - Name of kernel to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuLibraryGetKernel(CUkernel *pKernel, CUlibrary library, + const char *name); + +/** + * \brief Returns the number of kernels within a library + * + * Returns in \p count the number of kernels in \p lib. + * + * \param count - Number of kernels found within the library + * \param lib - Library to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + */ +CUresult CUDAAPI cuLibraryGetKernelCount(unsigned int *count, CUlibrary lib); + +/** + * \brief Retrieve the kernel handles within a library. + * + * Returns in \p kernels a maximum number of \p numKernels kernel handles within + * \p lib. The returned kernel handle becomes invalid when the library is + * unloaded. 
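+ *
+ * For example (an illustrative sketch; assumes \p lib was obtained from
+ * ::cuLibraryLoadData or ::cuLibraryLoadFromFile, allocation and error
+ * handling abbreviated):
+ * \code
+ * unsigned int count = 0;
+ * cuLibraryGetKernelCount(&count, lib);
+ * CUkernel *kernels = (CUkernel *)malloc(count * sizeof(CUkernel));
+ * cuLibraryEnumerateKernels(kernels, count, lib);
+ * // ... handles stay valid until the library is unloaded
+ * free(kernels);
+ * \endcode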
+ * + * \param kernels - Buffer where the kernel handles are returned to + * \param numKernels - Maximum number of kernel handles may be returned to the + * buffer \param lib - Library to query from + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuLibraryGetKernelCount + */ +CUresult CUDAAPI cuLibraryEnumerateKernels(CUkernel *kernels, + unsigned int numKernels, + CUlibrary lib); + +/** + * \brief Returns a module handle + * + * Returns in \p pMod the module handle associated with the current context + * located in library \p library. If module handle is not found, the call + * returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pMod - Returned module handle + * \param library - Library to retrieve module from + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuLibraryGetModule(CUmodule *pMod, CUlibrary library); + +/** + * \brief Returns a function handle + * + * Returns in \p pFunc the handle of the function for the requested kernel \p + * kernel and the current context. If function handle is not found, the call + * returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pFunc - Returned function handle + * \param kernel - Kernel to retrieve function for the requested context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetKernel, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuKernelGetFunction(CUfunction *pFunc, CUkernel kernel); + +/** + * \brief Returns a global device pointer + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the global + * with name \p name for the requested library \p library and the current + * context. If no global for the requested name \p name exists, the call returns + * ::CUDA_ERROR_NOT_FOUND. One of the parameters \p dptr or \p bytes (not both) + * can be NULL in which case it is ignored. + * + * \param dptr - Returned global device pointer for the requested context + * \param bytes - Returned global size in bytes + * \param library - Library to retrieve global from + * \param name - Name of global to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetModule, + * cuModuleGetGlobal + */ +CUresult CUDAAPI cuLibraryGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUlibrary library, const char *name); + +/** + * \brief Returns a pointer to managed memory + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the managed + * memory with name \p name for the requested library \p library. 
If no managed + * memory with the requested name \p name exists, the call returns + * ::CUDA_ERROR_NOT_FOUND. One of the parameters \p dptr or \p bytes (not both) + * can be NULL in which case it is ignored. Note that managed memory for library + * \p library is shared across devices and is registered when the library is + * loaded into at least one context. + * + * \note The API requires a CUDA context to be present and initialized on at + * least one device. If no context is present, the call returns + * ::CUDA_ERROR_NOT_FOUND. + * + * \param dptr - Returned pointer to the managed memory + * \param bytes - Returned memory size in bytes + * \param library - Library to retrieve managed memory from + * \param name - Name of managed memory to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload + */ +CUresult CUDAAPI cuLibraryGetManaged(CUdeviceptr *dptr, size_t *bytes, + CUlibrary library, const char *name); + +/** + * \brief Returns a pointer to a unified function + * + * Returns in \p *fptr the function pointer to a unified function denoted by \p + * symbol. If no unified function with name \p symbol exists, the call returns + * ::CUDA_ERROR_NOT_FOUND. If there is no device with attribute + * ::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS present in the system, the + * call may return ::CUDA_ERROR_NOT_FOUND. + * + * \param fptr - Returned pointer to a unified function + * \param library - Library to retrieve function pointer memory from + * \param symbol - Name of function pointer to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload + */ +CUresult CUDAAPI cuLibraryGetUnifiedFunction(void **fptr, CUlibrary library, + const char *symbol); + +/** + * \brief Returns information about a kernel + * + * Returns in \p *pi the integer value of the attribute \p attrib for the kernel + * \p kernel for the requested device \p dev. The supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK: The maximum number of threads + * per block, beyond which a launch of the kernel would fail. This number + * depends on both the kernel and the requested device. + * - ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES: The size in bytes of + * statically-allocated shared memory per block required by this kernel. + * This does not include dynamically-allocated shared memory requested by + * the user at runtime. + * - ::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES: The size in bytes of user-allocated + * constant memory required by this kernel. + * - ::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES: The size in bytes of local memory + * used by each thread of this kernel. + * - ::CU_FUNC_ATTRIBUTE_NUM_REGS: The number of registers used by each thread + * of this kernel. + * - ::CU_FUNC_ATTRIBUTE_PTX_VERSION: The PTX virtual architecture version for + * which the kernel was compiled. This value is the major PTX version * 10 + * + the minor PTX version, so a PTX version 1.3 function would return the + * value 13. Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. 
+ * - ::CU_FUNC_ATTRIBUTE_BINARY_VERSION: The binary architecture version for + * which the kernel was compiled. This value is the major binary + * version * 10 + the minor binary version, so a binary version 1.3 function + * would return the value 13. Note that this will return a value of 10 for + * legacy cubins that do not have a properly-encoded binary architecture + * version. + * - ::CU_FUNC_CACHE_MODE_CA: The attribute to indicate whether the kernel has + * been compiled with user specified option "-Xptxas --dlcm=ca" set. + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: The maximum size in + * bytes of dynamically-allocated shared memory. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared + * memory-L1 cache split ratio in percent of total shared memory. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET: If this attribute is set, the + * kernel must launch with a valid cluster size specified. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED: Indicates whether + * the function can be launched with non-portable cluster size. 1 is allowed, + * 0 is disallowed. A non-portable cluster size may only function on the + * specific SKUs the program is tested on. The launch might fail if the + * program is run on a different hardware platform. CUDA API provides + * cudaOccupancyMaxActiveClusters to assist with checking whether the desired + * size can be launched on the current device. A portable cluster size is + * guaranteed to be functional on all compute capabilities higher than the + * target compute capability. The portable cluster size for sm_90 is 8 blocks + * per cluster. This value may increase for future compute capabilities. The + * specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy. + * + * \note If another thread is trying to set the same attribute on the same + * device using + * ::cuKernelSetAttribute() simultaneously, the attribute query will give the + * old or new value depending on the interleavings chosen by the OS scheduler + * and memory consistency. + * + * \param pi - Returned attribute value + * \param attrib - Attribute requested + * \param kernel - Kernel to query attribute of + * \param dev - Device to query attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelSetAttribute, + * ::cuLibraryGetKernel, + * ::cuLaunchKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncGetAttribute + */ +CUresult CUDAAPI cuKernelGetAttribute(int *pi, CUfunction_attribute attrib, + CUkernel kernel, CUdevice dev); + +/** + * \brief Sets information about a kernel + * + * This call sets the value of a specified attribute \p attrib on the kernel \p + * kernel for the requested device \p dev to an integer value specified by \p + * val. 
This function returns CUDA_SUCCESS if the new value of the attribute + * could be successfully set. If the set fails, this call will return an error. + * Not all attributes can have values set. Attempting to set a value on a + * read-only attribute will result in an error (CUDA_ERROR_INVALID_VALUE) + * + * Note that attributes set using ::cuFuncSetAttribute() will override the + * attribute set by this API irrespective of whether the call to + * ::cuFuncSetAttribute() is made before or after this API call. However, + * ::cuKernelGetAttribute() will always return the attribute value set by this + * API. + * + * Supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: This is the maximum size + * in bytes of dynamically-allocated shared memory. The value should contain the + * requested maximum size of dynamically-allocated shared memory. The sum of + * this value and the function attribute ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES + * cannot exceed the device attribute + * ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN. The maximal size of + * requestable dynamic shared memory may differ by GPU architecture. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: On devices where the + * L1 cache and shared memory use the same hardware resources, this sets the + * shared memory carveout preference, in percent of the total shared memory. See + * ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR This is only a + * hint, and the driver can choose a different ratio if required to execute the + * function. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy. + * + * \note The API has stricter locking requirements in comparison to its legacy + * counterpart + * ::cuFuncSetAttribute() due to device-wide semantics. If multiple threads are + * trying to set the same attribute on the same device simultaneously, the + * attribute setting will depend on the interleavings chosen by the OS scheduler + * and memory consistency. 
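+ *
+ * As an illustrative sketch (assuming \p kern was previously obtained with
+ * ::cuLibraryGetKernel() and \p dev with ::cuDeviceGet(); both names are
+ * placeholders), the dynamic shared memory limit could be raised as follows:
+ * \code
+   // Request up to 64 KiB of dynamically-allocated shared memory per block;
+   // the value must stay within the device's opt-in shared memory limit.
+   CUresult status = cuKernelSetAttribute(
+       CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 64 * 1024, kern, dev);
+ * \endcode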
+ * + * \param attrib - Attribute requested + * \param val - Value to set + * \param kernel - Kernel to set attribute of + * \param dev - Device to set attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelGetAttribute, + * ::cuLibraryGetKernel, + * ::cuLaunchKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncSetAttribute + */ +CUresult CUDAAPI cuKernelSetAttribute(CUfunction_attribute attrib, int val, + CUkernel kernel, CUdevice dev); + +/** + * \brief Sets the preferred cache configuration for a device kernel. + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the device kernel \p kernel on the requested device \p dev. This is only a + * preference. The driver will use the requested configuration if possible, but + * it is free to choose a different configuration if required to execute \p + * kernel. Any context-wide preference set via ::cuCtxSetCacheConfig() will be + * overridden by this per-kernel setting. + * + * Note that attributes set using ::cuFuncSetCacheConfig() will override the + * attribute set by this API irrespective of whether the call to + * ::cuFuncSetCacheConfig() is made before or after this API call. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. + * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. + * + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 + * (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 + * cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \note The API has stricter locking requirements in comparison to its legacy + * counterpart + * ::cuFuncSetCacheConfig() due to device-wide semantics. If multiple threads + * are trying to set a config on the same device simultaneously, the cache + * config setting will depend on the interleavings chosen by the OS scheduler + * and memory consistency. + * + * \param kernel - Kernel to configure cache for + * \param config - Requested cache configuration + * \param dev - Device to set attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncSetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuLaunchKernel + */ +CUresult CUDAAPI cuKernelSetCacheConfig(CUkernel kernel, CUfunc_cache config, + CUdevice dev); + +/** + * \brief Returns the function name for a ::CUkernel handle + * + * Returns in \p **name the function name associated with the kernel handle \p + * hfunc . 
The function name is returned as a null-terminated string. The
+ * returned name is only valid when the kernel handle is valid. If the library
+ * is unloaded or reloaded, one must call the API again to get the updated name.
+ * This API may return a mangled name if the function is not declared as having
+ * C linkage. If either \p **name or \p hfunc is NULL,
+ * ::CUDA_ERROR_INVALID_VALUE is returned.
+ *
+ * \param name - The returned name of the function
+ * \param hfunc - The function handle to retrieve the name for
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ *
+ */
+CUresult CUDAAPI cuKernelGetName(const char **name, CUkernel hfunc);
+
+/**
+ * \brief Returns the offset and size of a kernel parameter in the device-side
+ * parameter layout
+ *
+ * Queries the kernel parameter at \p paramIndex into \p kernel's list of
+ * parameters, and returns in \p paramOffset and \p paramSize the offset and
+ * size, respectively, where the parameter will reside in the device-side
+ * parameter layout. This information can be used to update kernel node
+ * parameters from the device via ::cudaGraphKernelNodeSetParam() and
+ * ::cudaGraphKernelNodeUpdatesApply(). \p paramIndex must be less than the
+ * number of parameters that \p kernel takes. \p paramSize can be set to NULL if
+ * only the parameter offset is desired.
+ *
+ * \param kernel - The kernel to query
+ * \param paramIndex - The parameter index to query
+ * \param paramOffset - Returns the offset into the device-side parameter layout
+ * at which the parameter resides
+ * \param paramSize - Optionally returns the
+ * size of the parameter in the device-side parameter layout
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * \notefnerr
+ *
+ * \sa ::cuFuncGetParamInfo
+ */
+CUresult CUDAAPI cuKernelGetParamInfo(CUkernel kernel, size_t paramIndex,
+                                      size_t *paramOffset, size_t *paramSize);
+/** @} */ /* END CUDA_LIBRARY */
+
+/**
+ * \defgroup CUDA_MEM Memory Management
+ *
+ * ___MANBRIEF___ memory management functions of the low-level CUDA driver API
+ * (___CURRENT_FILE___) ___ENDMANBRIEF___
+ *
+ * This section describes the memory management functions of the low-level CUDA
+ * driver application programming interface.
+ *
+ * @{
+ */
+
+/**
+ * \brief Gets free and total memory
+ *
+ * Returns in \p *total the total amount of memory available to the current
+ * context. Returns in \p *free the amount of memory on the device that is free
+ * according to the OS. CUDA is not guaranteed to be able to allocate all of the
+ * memory that the OS reports as free. In a multi-tenant situation, the free
+ * estimate is subject to a race condition: an allocation or free performed by a
+ * different process, or by a different thread in the same process, between the
+ * time the free memory is estimated and the time it is reported can cause the
+ * reported value to deviate from the amount of memory that is actually free.
+ *
+ * The integrated GPU on Tegra shares memory with the CPU and other components
+ * of the SoC. The free and total values returned by the API exclude
+ * the SWAP memory space maintained by the OS on some platforms.
+ * The OS may move some of the memory pages into the swap area as the GPU or
+ * CPU allocates or accesses memory. See the Tegra app note on how to calculate
+ * total and free memory on Tegra.
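+ *
+ * As a minimal usage sketch (assuming the calling thread already has a current
+ * CUDA context), the free and total sizes can be queried as follows; the
+ * returned free value is only an estimate and may change concurrently:
+ * \code
+   size_t freeBytes = 0, totalBytes = 0;
+   if (cuMemGetInfo(&freeBytes, &totalBytes) == CUDA_SUCCESS) {
+       // freeBytes may already be stale by the time it is inspected
+   }
+ * \endcode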
+ * + * \param free - Returned free memory in bytes + * \param total - Returned total memory in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemGetInfo + */ +CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total); + +/** + * \brief Allocates device memory + * + * Allocates \p bytesize bytes of linear memory on the device and returns in + * \p *dptr a pointer to the allocated memory. The allocated memory is suitably + * aligned for any kind of variable. The memory is not cleared. If \p bytesize + * is 0, ::cuMemAlloc() returns ::CUDA_ERROR_INVALID_VALUE. + * + * \param dptr - Returned device pointer + * \param bytesize - Requested allocation size in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMalloc + */ +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize); + +/** + * \brief Allocates pitched device memory + * + * Allocates at least \p WidthInBytes * \p Height bytes of linear memory on + * the device and returns in \p *dptr a pointer to the allocated memory. The + * function may pad the allocation to ensure that corresponding pointers in + * any given row will continue to meet the alignment requirements for + * coalescing as the address is updated from row to row. \p ElementSizeBytes + * specifies the size of the largest reads and writes that will be performed + * on the memory range. \p ElementSizeBytes may be 4, 8 or 16 (since coalesced + * memory transactions are not possible on other data sizes). If + * \p ElementSizeBytes is smaller than the actual read/write size of a kernel, + * the kernel will run correctly, but possibly at reduced speed. The pitch + * returned in \p *pPitch by ::cuMemAllocPitch() is the width in bytes of the + * allocation. 
The intended usage of pitch is as a separate parameter of the + * allocation, used to compute addresses within the 2D array. Given the row + * and column of an array element of type \b T, the address is computed as: + * \code + T* pElement = (T*)((char*)BaseAddress + Row * Pitch) + Column; + * \endcode + * + * The pitch returned by ::cuMemAllocPitch() is guaranteed to work with + * ::cuMemcpy2D() under all circumstances. For allocations of 2D arrays, it is + * recommended that programmers consider performing pitch allocations using + * ::cuMemAllocPitch(). Due to alignment restrictions in the hardware, this is + * especially true if the application will be performing 2D memory copies + * between different regions of device memory (whether linear memory or CUDA + * arrays). + * + * The byte alignment of the pitch returned by ::cuMemAllocPitch() is guaranteed + * to match or exceed the alignment requirement for texture binding with + * ::cuTexRefSetAddress2D(). + * + * \param dptr - Returned device pointer + * \param pPitch - Returned pitch of allocation in bytes + * \param WidthInBytes - Requested allocation width in bytes + * \param Height - Requested allocation height in rows + * \param ElementSizeBytes - Size of largest reads/writes for range + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocPitch + */ +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, + size_t WidthInBytes, size_t Height, + unsigned int ElementSizeBytes); + +/** + * \brief Frees device memory + * + * Frees the memory space pointed to by \p dptr, which must have been returned + * by a previous call to one of the following memory allocation APIs - + * ::cuMemAlloc(), + * ::cuMemAllocPitch(), ::cuMemAllocManaged(), ::cuMemAllocAsync(), + * ::cuMemAllocFromPoolAsync() + * + * Note - This API will not perform any implicit synchronization when the + * pointer was allocated with + * ::cuMemAllocAsync or ::cuMemAllocFromPoolAsync. Callers must ensure that all + * accesses to the pointer have completed before invoking ::cuMemFree. For best + * performance and memory reuse, users should use ::cuMemFreeAsync to free + * memory allocated via the stream ordered memory allocator. 
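+ *
+ * A minimal allocate/use/free round trip might look like the following sketch
+ * (assuming a current context exists and \p numBytes is a caller-chosen size):
+ * \code
+   CUdeviceptr dptr;
+   if (cuMemAlloc(&dptr, numBytes) == CUDA_SUCCESS) {
+       // ... use the allocation ...
+       cuMemFree(dptr);
+   }
+ * \endcode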
+ * + * \param dptr - Pointer to memory to free + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemAllocManaged, ::cuMemAllocAsync, + * ::cuMemAllocFromPoolAsync, + * ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, ::cuMemcpy3D, + * ::cuMemcpy3DAsync, + * ::cuMemcpyAtoA, ::cuMemcpyAtoD, ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, + * ::cuMemcpyDtoA, + * ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, + * ::cuMemcpyHtoA, + * ::cuMemcpyHtoAAsync, ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, ::cuMemFreeAsync, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFree + */ +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr); + +/** + * \brief Get information on memory allocations + * + * Returns the base address in \p *pbase and size in \p *psize of the + * allocation by ::cuMemAlloc() or ::cuMemAllocPitch() that contains the input + * pointer \p dptr. Both parameters \p pbase and \p psize are optional. If one + * of them is NULL, it is ignored. + * + * \param pbase - Returned base address + * \param psize - Returned size of device memory allocation + * \param dptr - Device pointer to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32 + */ +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, + CUdeviceptr dptr); + +/** + * \brief Allocates page-locked host memory + * + * Allocates \p bytesize bytes of host memory that is page-locked and + * accessible to the device. The driver tracks the virtual memory ranges + * allocated with this function and automatically accelerates calls to + * functions such as ::cuMemcpy(). Since the memory can be accessed directly by + * the device, it can be read or written with much higher bandwidth than + * pageable memory obtained with functions such as ::malloc(). + * + * On systems where + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES is true, + * ::cuMemAllocHost may not page-lock the allocated memory. + * + * Page-locking excessive amounts of memory with ::cuMemAllocHost() may degrade + * system performance, since it reduces the amount of memory available to the + * system for paging. 
As a result, this function is best used sparingly to + * allocate staging areas for data exchange between host and device. + * + * Note all host memory allocated using ::cuMemAllocHost() will automatically + * be immediately accessible to all contexts on all devices which support + * unified addressing (as may be queried using + * ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING). The device pointer that may be + * used to access this host memory from those contexts is always equal to the + * returned host pointer \p *pp. See \ref CUDA_UNIFIED for additional details. + * + * \param pp - Returned pointer to host memory + * \param bytesize - Requested allocation size in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocHost + */ +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize); + +/** + * \brief Frees page-locked host memory + * + * Frees the memory space pointed to by \p p, which must have been returned by + * a previous call to ::cuMemAllocHost(). + * + * \param p - Pointer to memory to free + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFreeHost + */ +CUresult CUDAAPI cuMemFreeHost(void *p); + +/** + * \brief Allocates page-locked host memory + * + * Allocates \p bytesize bytes of host memory that is page-locked and accessible + * to the device. The driver tracks the virtual memory ranges allocated with + * this function and automatically accelerates calls to functions such as + * ::cuMemcpyHtoD(). Since the memory can be accessed directly by the device, + * it can be read or written with much higher bandwidth than pageable memory + * obtained with functions such as ::malloc(). + * + * On systems where + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES is true, + * ::cuMemHostAlloc may not page-lock the allocated memory. 
+ * + * Page-locking excessive amounts of memory may degrade system performance, + * since it reduces the amount of memory available to the system for paging. + * As a result, this function is best used sparingly to allocate staging areas + * for data exchange between host and device. + * + * The \p Flags parameter enables different options to be specified that + * affect the allocation, as follows. + * + * - ::CU_MEMHOSTALLOC_PORTABLE: The memory returned by this call will be + * considered as pinned memory by all CUDA contexts, not just the one that + * performed the allocation. + * + * - ::CU_MEMHOSTALLOC_DEVICEMAP: Maps the allocation into the CUDA address + * space. The device pointer to the memory may be obtained by calling + * ::cuMemHostGetDevicePointer(). + * + * - ::CU_MEMHOSTALLOC_WRITECOMBINED: Allocates the memory as write-combined + * (WC). WC memory can be transferred across the PCI Express bus more + * quickly on some system configurations, but cannot be read efficiently by + * most CPUs. WC memory is a good option for buffers that will be written by + * the CPU and read by the GPU via mapped pinned memory or host->device + * transfers. + * + * All of these flags are orthogonal to one another: a developer may allocate + * memory that is portable, mapped and/or write-combined with no restrictions. + * + * The ::CU_MEMHOSTALLOC_DEVICEMAP flag may be specified on CUDA contexts for + * devices that do not support mapped pinned memory. The failure is deferred + * to ::cuMemHostGetDevicePointer() because the memory may be mapped into + * other CUDA contexts via the ::CU_MEMHOSTALLOC_PORTABLE flag. + * + * The memory allocated by this function must be freed with ::cuMemFreeHost(). + * + * Note all host memory allocated using ::cuMemHostAlloc() will automatically + * be immediately accessible to all contexts on all devices which support + * unified addressing (as may be queried using + * ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING). Unless the flag + * ::CU_MEMHOSTALLOC_WRITECOMBINED is specified, the device pointer that may be + * used to access this host memory from those contexts is always equal to the + * returned host pointer \p *pp. If the flag ::CU_MEMHOSTALLOC_WRITECOMBINED is + * specified, then the function ::cuMemHostGetDevicePointer() must be used to + * query the device pointer, even if the context supports unified addressing. + * See \ref CUDA_UNIFIED for additional details. 
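+ *
+ * As a sketch of a portable, mapped allocation (with \p bytesize chosen by the
+ * caller), one might write:
+ * \code
+   void *hostPtr = NULL;
+   CUdeviceptr devPtr = 0;
+   cuMemHostAlloc(&hostPtr, bytesize,
+                  CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP);
+   cuMemHostGetDevicePointer(&devPtr, hostPtr, 0);
+   // devPtr is usable from kernels, hostPtr from the CPU
+   cuMemFreeHost(hostPtr);
+ * \endcode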
+ * + * \param pp - Returned pointer to host memory + * \param bytesize - Requested allocation size in bytes + * \param Flags - Flags for allocation request + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaHostAlloc + */ +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, unsigned int Flags); + +/** + * \brief Passes back device pointer of mapped pinned memory + * + * Passes back the device pointer \p pdptr corresponding to the mapped, pinned + * host buffer \p p allocated by ::cuMemHostAlloc. + * + * ::cuMemHostGetDevicePointer() will fail if the ::CU_MEMHOSTALLOC_DEVICEMAP + * flag was not specified at the time the memory was allocated, or if the + * function is called on a GPU that does not support mapped pinned memory. + * + * For devices that have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, the memory + * can also be accessed from the device using the host pointer \p p. + * The device pointer returned by ::cuMemHostGetDevicePointer() may or may not + * match the original host pointer \p p and depends on the devices visible to + * the application. If all devices visible to the application have a non-zero + * value for the device attribute, the device pointer returned by + * ::cuMemHostGetDevicePointer() will match the original pointer \p p. If any + * device visible to the application has a zero value for the device attribute, + * the device pointer returned by + * ::cuMemHostGetDevicePointer() will not match the original host pointer \p p, + * but it will be suitable for use on all devices provided Unified Virtual + * Addressing is enabled. In such systems, it is valid to access the memory + * using either pointer on devices that have a non-zero value for the device + * attribute. Note however that such devices should access the memory using only + * one of the two pointers and not both. + * + * \p Flags provides for future releases. For now, it must be set to 0. 
+ * + * \param pdptr - Returned device pointer + * \param p - Host pointer + * \param Flags - Options (must be 0) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaHostGetDevicePointer + */ +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int Flags); + +/** + * \brief Passes back flags that were used for a pinned allocation + * + * Passes back the flags \p pFlags that were specified when allocating + * the pinned host buffer \p p allocated by ::cuMemHostAlloc. + * + * ::cuMemHostGetFlags() will fail if the pointer does not reside in + * an allocation performed by ::cuMemAllocHost() or ::cuMemHostAlloc(). + * + * \param pFlags - Returned flags word + * \param p - Host pointer + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuMemAllocHost, + * ::cuMemHostAlloc, + * ::cudaHostGetFlags + */ +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p); + +/** + * \brief Allocates memory that will be automatically managed by the Unified + * Memory system + * + * Allocates \p bytesize bytes of managed memory on the device and returns in + * \p *dptr a pointer to the allocated memory. If the device doesn't support + * allocating managed memory, ::CUDA_ERROR_NOT_SUPPORTED is returned. Support + * for managed memory can be queried using the device attribute + * ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY. The allocated memory is suitably + * aligned for any kind of variable. The memory is not cleared. If \p bytesize + * is 0, ::cuMemAllocManaged returns ::CUDA_ERROR_INVALID_VALUE. The pointer + * is valid on the CPU and on all GPUs in the system that support managed + * memory. All accesses to this pointer must obey the Unified Memory programming + * model. + * + * \p flags specifies the default stream association for this allocation. + * \p flags must be one of ::CU_MEM_ATTACH_GLOBAL or ::CU_MEM_ATTACH_HOST. If + * ::CU_MEM_ATTACH_GLOBAL is specified, then this memory is accessible from + * any stream on any device. If ::CU_MEM_ATTACH_HOST is specified, then the + * allocation should not be accessed from devices that have a zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS; an explicit + * call to + * ::cuStreamAttachMemAsync will be required to enable access on such devices. + * + * If the association is later changed via ::cuStreamAttachMemAsync to + * a single stream, the default association as specified during + * ::cuMemAllocManaged is restored when that stream is destroyed. 
For + * __managed__ variables, the default association is always + * ::CU_MEM_ATTACH_GLOBAL. Note that destroying a stream is an asynchronous + * operation, and as a result, the change to default association won't happen + * until all work in the stream has completed. + * + * Memory allocated with ::cuMemAllocManaged should be released with + * ::cuMemFree. + * + * Device memory oversubscription is possible for GPUs that have a non-zero + * value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Managed memory on such GPUs + * may be evicted from device memory to host memory at any time by the Unified + * Memory driver in order to make room for other allocations. + * + * In a system where all GPUs have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, managed memory may not be + * populated when this API returns and instead may be populated on access. In + * such systems, managed memory can migrate to any processor's memory at any + * time. The Unified Memory driver will employ heuristics to maintain data + * locality and prevent excessive page faults to the extent possible. The + * application can also guide the driver about memory usage patterns via + * ::cuMemAdvise. The application can also explicitly migrate memory to a + * desired processor's memory via + * ::cuMemPrefetchAsync. + * + * In a multi-GPU system where all of the GPUs have a zero value for the device + * attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS and all the GPUs have + * peer-to-peer support with each other, the physical storage for managed memory + * is created on the GPU which is active at the time ::cuMemAllocManaged is + * called. All other GPUs will reference the data at reduced bandwidth via peer + * mappings over the PCIe bus. The Unified Memory driver does not migrate memory + * among such GPUs. + * + * In a multi-GPU system where not all GPUs have peer-to-peer support with each + * other and where the value of the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS is zero for at least one of + * those GPUs, the location chosen for physical storage of managed memory is + * system-dependent. + * - On Linux, the location chosen will be device memory as long as the current + * set of active contexts are on devices that either have peer-to-peer support + * with each other or have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. If there is an active + * context on a GPU that does not have a non-zero value for that device + * attribute and it does not have peer-to-peer support with the other devices + * that have active contexts on them, then the location for physical storage + * will be 'zero-copy' or host memory. Note that this means that managed memory + * that is located in device memory is migrated to host memory if a new context + * is created on a GPU that doesn't have a non-zero value for the device + * attribute and does not support peer-to-peer with at least one of the other + * devices that has an active context. This in turn implies that context + * creation may fail if there is insufficient host memory to migrate all managed + * allocations. + * - On Windows, the physical storage is always created in 'zero-copy' or host + * memory. All GPUs will reference the data at reduced bandwidth over the PCIe + * bus. 
In these circumstances, use of the environment variable + * CUDA_VISIBLE_DEVICES is recommended to restrict CUDA to only use those GPUs + * that have peer-to-peer support. Alternatively, users can also set + * CUDA_MANAGED_FORCE_DEVICE_ALLOC to a non-zero value to force the driver to + * always use device memory for physical storage. When this environment variable + * is set to a non-zero value, all contexts created in that process on devices + * that support managed memory have to be peer-to-peer compatible with each + * other. Context creation will fail if a context is created on a device that + * supports managed memory and is not peer-to-peer compatible with any of the + * other managed memory supporting devices on which contexts were previously + * created, even if those contexts have been destroyed. These environment + * variables are described in the CUDA programming guide under the "CUDA + * environment variables" section. + * - On ARM, managed memory is not available on discrete gpu with Drive PX-2. + * + * \param dptr - Returned device pointer + * \param bytesize - Requested allocation size in bytes + * \param flags - Must be one of ::CU_MEM_ATTACH_GLOBAL or + * ::CU_MEM_ATTACH_HOST + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cuDeviceGetAttribute, ::cuStreamAttachMemAsync, + * ::cudaMallocManaged + */ +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, + unsigned int flags); + +/** + * \brief Registers a callback function to receive async notifications + * + * Registers \p callbackFunc to receive async notifications. + * + * The \p userData parameter is passed to the callback function at async + * notification time. Likewise, \p callback is also passed to the callback + * function to distinguish between multiple registered callbacks. + * + * The callback function being registered should be designed to return quickly + * (~10ms). Any long running tasks should be queued for execution on an + * application thread. + * + * Callbacks may not call cuDeviceRegisterAsyncNotification or + * cuDeviceUnregisterAsyncNotification. Doing so will result in + * ::CUDA_ERROR_NOT_PERMITTED. Async notification callbacks execute in an + * undefined order and may be serialized. + * + * Returns in \p *callback a handle representing the registered callback + * instance. + * + * \param device - The device on which to register the callback + * \param callbackFunc - The function to register as a callback + * \param userData - A generic pointer to user data. This is passed into the + * callback function. 
\param callback - A handle representing the registered + * callback instance + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_NOT_SUPPORTED + * ::CUDA_ERROR_INVALID_DEVICE + * ::CUDA_ERROR_INVALID_VALUE + * ::CUDA_ERROR_NOT_PERMITTED + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuDeviceUnregisterAsyncNotification + */ +CUresult CUDAAPI cuDeviceRegisterAsyncNotification( + CUdevice device, CUasyncCallback callbackFunc, void *userData, + CUasyncCallbackHandle *callback); + +/** + * \brief Unregisters an async notification callback + * + * Unregisters \p callback so that the corresponding callback function will stop + * receiving async notifications. + * + * \param device - The device from which to remove \p callback. + * \param callback - The callback instance to unregister from receiving async + * notifications. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_NOT_SUPPORTED + * ::CUDA_ERROR_INVALID_DEVICE + * ::CUDA_ERROR_INVALID_VALUE + * ::CUDA_ERROR_NOT_PERMITTED + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuDeviceRegisterAsyncNotification + */ +CUresult CUDAAPI cuDeviceUnregisterAsyncNotification( + CUdevice device, CUasyncCallbackHandle callback); + +/** + * \brief Returns a handle to a compute device + * + * Returns in \p *device a device handle given a PCI bus ID string. + * + * \param dev - Returned device handle + * + * \param pciBusId - String in one of the following forms: + * [domain]:[bus]:[device].[function] + * [domain]:[bus]:[device] + * [bus]:[device].[function] + * where \p domain, \p bus, \p device, and \p function are all hexadecimal + * values + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGet, + * ::cuDeviceGetAttribute, + * ::cuDeviceGetPCIBusId, + * ::cudaDeviceGetByPCIBusId + */ +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId); + +/** + * \brief Returns a PCI Bus Id string for the device + * + * Returns an ASCII string identifying the device \p dev in the NULL-terminated + * string pointed to by \p pciBusId. \p len specifies the maximum length of the + * string that may be returned. + * + * \param pciBusId - Returned identifier string for the device in the following + * format [domain]:[bus]:[device].[function] where \p domain, \p bus, \p device, + * and \p function are all hexadecimal values. pciBusId should be large enough + * to store 13 characters including the NULL-terminator. + * + * \param len - Maximum length of string to store in \p name + * + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGet, + * ::cuDeviceGetAttribute, + * ::cuDeviceGetByPCIBusId, + * ::cudaDeviceGetPCIBusId + */ +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev); + +/** + * \brief Gets an interprocess handle for a previously allocated event + * + * Takes as input a previously allocated event. This event must have been + * created with the ::CU_EVENT_INTERPROCESS and ::CU_EVENT_DISABLE_TIMING + * flags set. This opaque handle may be copied into other processes and + * opened with ::cuIpcOpenEventHandle to allow efficient hardware + * synchronization between GPU work in different processes. 
+ * + * After the event has been opened in the importing process, + * ::cuEventRecord, ::cuEventSynchronize, ::cuStreamWaitEvent and + * ::cuEventQuery may be used in either process. Performing operations + * on the imported event after the exported event has been freed + * with ::cuEventDestroy will result in undefined behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED + * + * \param pHandle - Pointer to a user allocated CUipcEventHandle + * in which to return the opaque event handle + * \param event - Event allocated with ::CU_EVENT_INTERPROCESS and + * ::CU_EVENT_DISABLE_TIMING flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuEventCreate, + * ::cuEventDestroy, + * ::cuEventSynchronize, + * ::cuEventQuery, + * ::cuStreamWaitEvent, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcGetEventHandle + */ +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event); + +/** + * \brief Opens an interprocess event handle for use in the current process + * + * Opens an interprocess event handle exported from another process with + * ::cuIpcGetEventHandle. This function returns a ::CUevent that behaves like + * a locally created event with the ::CU_EVENT_DISABLE_TIMING flag specified. + * This event must be freed with ::cuEventDestroy. + * + * Performing operations on the imported event after the exported event has + * been freed with ::cuEventDestroy will result in undefined behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED + * + * \param phEvent - Returns the imported event + * \param handle - Interprocess handle to open + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuEventCreate, + * ::cuEventDestroy, + * ::cuEventSynchronize, + * ::cuEventQuery, + * ::cuStreamWaitEvent, + * ::cuIpcGetEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcOpenEventHandle + */ +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, + CUipcEventHandle handle); + +/** + * \brief Gets an interprocess memory handle for an existing device memory + * allocation + * + * Takes a pointer to the base of an existing device memory allocation created + * with ::cuMemAlloc and exports it for use in another process. This is a + * lightweight operation and may be called multiple times on an allocation + * without adverse effects. + * + * If a region of memory is freed with ::cuMemFree and a subsequent call + * to ::cuMemAlloc returns memory with the same device address, + * ::cuIpcGetMemHandle will return a unique handle for the + * new memory. 
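+ *
+ * In the exporting process, a sketch of handle creation (assuming \p bytesize
+ * is a caller-chosen size and the opaque handle is then sent to the importing
+ * process over some IPC channel such as a pipe or socket) might look like:
+ * \code
+   CUdeviceptr dptr;
+   CUipcMemHandle handle;
+   cuMemAlloc(&dptr, bytesize);
+   cuIpcGetMemHandle(&handle, dptr);
+   // transmit the bytes of handle to the importing process
+ * \endcode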
+ * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED + * + * \param pHandle - Pointer to user allocated ::CUipcMemHandle to return + * the handle in. + * \param dptr - Base pointer to previously allocated device memory + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcGetMemHandle + */ +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); + +/** + * \brief Opens an interprocess memory handle exported from another process + * and returns a device pointer usable in the local process. + * + * Maps memory exported from another process with ::cuIpcGetMemHandle into + * the current device address space. For contexts on different devices + * ::cuIpcOpenMemHandle can attempt to enable peer access between the + * devices as if the user called ::cuCtxEnablePeerAccess. This behavior is + * controlled by the ::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag. + * ::cuDeviceCanAccessPeer can determine if a mapping is possible. + * + * Contexts that may open ::CUipcMemHandles are restricted in the following way. + * ::CUipcMemHandles from each ::CUdevice in a given process may only be opened + * by one ::CUcontext per ::CUdevice per other process. + * + * If the memory handle has already been opened by the current context, the + * reference count on the handle is incremented by 1 and the existing device + * pointer is returned. + * + * Memory returned from ::cuIpcOpenMemHandle must be freed with + * ::cuIpcCloseMemHandle. + * + * Calling ::cuMemFree on an exported memory region before calling + * ::cuIpcCloseMemHandle in the importing context will result in undefined + * behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED + * + * \param pdptr - Returned device pointer + * \param handle - ::CUipcMemHandle to open + * \param Flags - Flags for this operation. Must be specified as + * ::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_TOO_MANY_PEERS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \note No guarantees are made about the address returned in \p *pdptr. + * In particular, multiple processes may not receive the same address for the + * same \p handle. 
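+ *
+ * In the importing process, a matching sketch (assuming \p handle holds the
+ * ::CUipcMemHandle received from the exporting process) might be:
+ * \code
+   CUdeviceptr dptr = 0;
+   cuIpcOpenMemHandle(&dptr, handle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
+   // ... access the imported allocation through dptr ...
+   cuIpcCloseMemHandle(dptr);
+ * \endcode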
+ * + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcCloseMemHandle, + * ::cuCtxEnablePeerAccess, + * ::cuDeviceCanAccessPeer, + * ::cudaIpcOpenMemHandle + */ +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags); + +/** + * \brief Attempts to close memory mapped with ::cuIpcOpenMemHandle + * + * Decrements the reference count of the memory returned by ::cuIpcOpenMemHandle + * by 1. When the reference count reaches 0, this API unmaps the memory. The + * original allocation in the exporting process as well as imported mappings in + * other processes will be unaffected. + * + * Any resources used to enable peer access will be freed if this is the + * last mapping using them. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED + * + * \param dptr - Device pointer returned by ::cuIpcOpenMemHandle + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cudaIpcCloseMemHandle + */ +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); + +/** + * \brief Registers an existing host memory range for use by CUDA + * + * Page-locks the memory range specified by \p p and \p bytesize and maps it + * for the device(s) as specified by \p Flags. This memory range also is added + * to the same tracking mechanism as ::cuMemHostAlloc to automatically + * accelerate calls to functions such as ::cuMemcpyHtoD(). Since the memory can + * be accessed directly by the device, it can be read or written with much + * higher bandwidth than pageable memory that has not been registered. + * Page-locking excessive amounts of memory may degrade system performance, + * since it reduces the amount of memory available to the system for paging. As + * a result, this function is best used sparingly to register staging areas for + * data exchange between host and device. + * + * On systems where + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES is true, + * ::cuMemHostRegister will not page-lock the memory range specified by \p ptr + * but only populate unpopulated pages. + * + * The \p Flags parameter enables different options to be specified that + * affect the allocation, as follows. + * + * - ::CU_MEMHOSTREGISTER_PORTABLE: The memory returned by this call will be + * considered as pinned memory by all CUDA contexts, not just the one that + * performed the allocation. + * + * - ::CU_MEMHOSTREGISTER_DEVICEMAP: Maps the allocation into the CUDA address + * space. The device pointer to the memory may be obtained by calling + * ::cuMemHostGetDevicePointer(). + * + * - ::CU_MEMHOSTREGISTER_IOMEMORY: The pointer is treated as pointing to some + * I/O memory space, e.g. the PCI Express resource of a 3rd party device. + * + * - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to + * memory that is considered read-only by the device. 
On platforms without
+ * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this
+ * flag is required in order to register memory mapped to the CPU as read-only.
+ * Support for the use of this flag can be queried from the device attribute
+ * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag
+ * with a current context associated with a device that does not have this
+ * attribute set will cause ::cuMemHostRegister to fail with
+ * CUDA_ERROR_NOT_SUPPORTED.
+ *
+ * All of these flags are orthogonal to one another: a developer may page-lock
+ * memory that is portable or mapped with no restrictions.
+ *
+ * The ::CU_MEMHOSTREGISTER_DEVICEMAP flag may be specified on CUDA contexts for
+ * devices that do not support mapped pinned memory. The failure is deferred
+ * to ::cuMemHostGetDevicePointer() because the memory may be mapped into
+ * other CUDA contexts via the ::CU_MEMHOSTREGISTER_PORTABLE flag.
+ *
+ * For devices that have a non-zero value for the device attribute
+ * ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, the memory
+ * can also be accessed from the device using the host pointer \p p.
+ * The device pointer returned by ::cuMemHostGetDevicePointer() may or may not
+ * match the original host pointer \p ptr and depends on the devices visible to
+ * the application. If all devices visible to the application have a non-zero
+ * value for the device attribute, the device pointer returned by
+ * ::cuMemHostGetDevicePointer() will match the original pointer \p ptr. If any
+ * device visible to the application has a zero value for the device attribute,
+ * the device pointer returned by
+ * ::cuMemHostGetDevicePointer() will not match the original host pointer \p
+ * ptr, but it will be suitable for use on all devices provided Unified Virtual
+ * Addressing is enabled. In such systems, it is valid to access the memory
+ * using either pointer on devices that have a non-zero value for the device
+ * attribute. Note however that such devices should access the memory using only
+ * one of the two pointers and not both.
+ *
+ * The memory page-locked by this function must be unregistered with
+ * ::cuMemHostUnregister().
+ *
+ * \param p - Host pointer to memory to page-lock
+ * \param bytesize - Size in bytes of the address range to page-lock
+ * \param Flags - Flags for allocation request
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_OUT_OF_MEMORY,
+ * ::CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED,
+ * ::CUDA_ERROR_NOT_PERMITTED,
+ * ::CUDA_ERROR_NOT_SUPPORTED
+ * \notefnerr
+ *
+ * \sa
+ * ::cuMemHostUnregister,
+ * ::cuMemHostGetFlags,
+ * ::cuMemHostGetDevicePointer,
+ * ::cudaHostRegister
+ */
+CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize,
+                                   unsigned int Flags);
+
+/**
+ * \brief Unregisters a memory range that was registered with cuMemHostRegister.
+ *
+ * Unmaps the memory range whose base address is specified by \p p, and makes
+ * it pageable again.
+ *
+ * The base address must be the same one specified to ::cuMemHostRegister().
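+ *
+ * A register/unregister round trip might look like the following sketch
+ * (assuming \p bytesize is a caller-chosen size):
+ * \code
+   void *p = malloc(bytesize);
+   cuMemHostRegister(p, bytesize, CU_MEMHOSTREGISTER_PORTABLE);
+   // ... use p as a staging buffer for host/device transfers ...
+   cuMemHostUnregister(p);
+   free(p);
+ * \endcode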
+ * + * \param p - Host pointer to memory to unregister + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED, + * \notefnerr + * + * \sa + * ::cuMemHostRegister, + * ::cudaHostUnregister + */ +CUresult CUDAAPI cuMemHostUnregister(void *p); + +/** + * \brief Copies memory + * + * Copies data between two pointers. + * \p dst and \p src are base pointers of the destination and source, + * respectively. \p ByteCount specifies the number of bytes to copy. Note that + * this function infers the type of the transfer (host to host, host to device, + * device to device, or device to host) from the pointer values. This function + * is only allowed in contexts which support unified addressing. + * + * \param dst - Destination unified virtual address space pointer + * \param src - Source unified virtual address space pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyToSymbol, + * ::cudaMemcpyFromSymbol + */ +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount); + +/** + * \brief Copies device memory between two contexts + * + * Copies from device memory in one context to device memory in another + * context. \p dstDevice is the base device pointer of the destination memory + * and \p dstContext is the destination context. \p srcDevice is the base + * device pointer of the source memory and \p srcContext is the source pointer. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstDevice - Destination device pointer + * \param dstContext - Destination context + * \param srcDevice - Source device pointer + * \param srcContext - Source context + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuMemcpyDtoD, ::cuMemcpy3DPeer, ::cuMemcpyDtoDAsync, + * ::cuMemcpyPeerAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpyPeer + */ +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount); + +/** + * \brief Copies memory from Host to Device + * + * Copies from host memory to device memory. \p dstDevice and \p srcHost are + * the base addresses of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. 
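+ *
+ * A synchronous upload/download pair might look like the following sketch
+ * (assuming \p dptr was allocated with ::cuMemAlloc() and \p hostBuf holds at
+ * least \p numBytes valid bytes; the names are placeholders):
+ * \code
+   cuMemcpyHtoD(dptr, hostBuf, numBytes);
+   // ... launch work that reads dptr ...
+   cuMemcpyDtoH(hostBuf, dptr, numBytes);
+ * \endcode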
+ * + * \param dstDevice - Destination device pointer + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyToSymbol + */ +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount); + +/** + * \brief Copies memory from Device to Host + * + * Copies from device to host memory. \p dstHost and \p srcDevice specify the + * base pointers of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstHost - Destination host pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyFromSymbol + */ +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount); + +/** + * \brief Copies memory from Device to Device + * + * Copies from device memory to device memory. \p dstDevice and \p srcDevice + * are the base pointers of the destination and source, respectively. + * \p ByteCount specifies the number of bytes to copy. 
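+ *
+ * \par
+ * A minimal sketch that duplicates one device allocation into another within
+ * the current context; names and sizes are illustrative and error checking
+ * is omitted:
+ * \code
+   CUdeviceptr src, dst;
+   size_t      bytes = 4096;           // illustrative size
+   cuMemAlloc(&src, bytes);
+   cuMemAlloc(&dst, bytes);
+   // ... populate src, e.g. with cuMemcpyHtoD() or a kernel ...
+   cuMemcpyDtoD(dst, src, bytes);      // device-to-device copy
+ * \endcode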
+ * + * \param dstDevice - Destination device pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyToSymbol, + * ::cudaMemcpyFromSymbol + */ +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount); + +/** + * \brief Copies memory from Device to Array + * + * Copies from device memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting index of the destination data. + * \p srcDevice specifies the base pointer of the source. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyToArray + */ +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset, + CUdeviceptr srcDevice, size_t ByteCount); + +/** + * \brief Copies memory from Array to Device + * + * Copies from one 1D CUDA array to device memory. \p dstDevice specifies the + * base pointer of the destination and must be naturally aligned with the CUDA + * array elements. \p srcArray and \p srcOffset specify the CUDA array handle + * and the offset in bytes into the array where the copy is to begin. + * \p ByteCount specifies the number of bytes to copy and must be evenly + * divisible by the array element size. 
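+ *
+ * \par
+ * A minimal sketch, assuming \c hSrcArray is an existing 1D CUDA array of
+ * 256 float elements created elsewhere with ::cuArrayCreate() (names are
+ * illustrative; error checking omitted):
+ * \code
+   CUdeviceptr dptr;
+   size_t      bytes = 256 * sizeof(float);
+   cuMemAlloc(&dptr, bytes);
+   cuMemcpyAtoD(dptr, hSrcArray, 0, bytes);   // srcOffset = 0
+ * \endcode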
+ * + * \param dstDevice - Destination device pointer + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyFromArray + */ +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, + size_t srcOffset, size_t ByteCount); + +/** + * \brief Copies memory from Host to Array + * + * Copies from host memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting offset in bytes of the destination + * data. \p pSrc specifies the base address of the source. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyToArray + */ +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount); + +/** + * \brief Copies memory from Array to Host + * + * Copies from one 1D CUDA array to host memory. \p dstHost specifies the base + * pointer of the destination. \p srcArray and \p srcOffset specify the CUDA + * array handle and starting offset in bytes of the source data. + * \p ByteCount specifies the number of bytes to copy. 
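+ *
+ * \par
+ * A minimal sketch, assuming \c hArray is an existing 1D CUDA array holding
+ * 256 float elements (names are illustrative; error checking omitted):
+ * \code
+   float hostOut[256];
+   cuMemcpyAtoH(hostOut, hArray, 0, sizeof(hostOut));   // srcOffset = 0
+ * \endcode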
+ * + * \param dstHost - Destination device pointer + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyFromArray + */ +CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset, + size_t ByteCount); + +/** + * \brief Copies memory from Array to Array + * + * Copies from one 1D CUDA array to another. \p dstArray and \p srcArray + * specify the handles of the destination and source CUDA arrays for the copy, + * respectively. \p dstOffset and \p srcOffset specify the destination and + * source offsets in bytes into the CUDA arrays. \p ByteCount is the number of + * bytes to be copied. The size of the elements in the CUDA arrays need not be + * the same format, but the elements must be the same size; and count must be + * evenly divisible by that size. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyArrayToArray + */ +CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset, + CUarray srcArray, size_t srcOffset, + size_t ByteCount); + +/** + * \brief Copies memory for 2D arrays + * + * Perform a 2D memory copy according to the parameters specified in \p pCopy. 
+ * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice and ::srcPitch are + * ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are + * ignored. + * + * - ::srcXInBytes and ::srcY specify the base address of the source data for + * the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - ::dstXInBytes and ::dstY specify the base address of the destination data + * for the copy. 
+ * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes and ::Height specify the width (in bytes) and height of + * the 2D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * + * \par + * ::cuMemcpy2D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back + * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies + * (device to device, CUDA array to device, CUDA array to CUDA array), + * ::cuMemcpy2D() may fail for pitches not computed by ::cuMemAllocPitch(). + * ::cuMemcpy2DUnaligned() does not have this restriction, but may run + * significantly slower in the cases where ::cuMemcpy2D() would have returned + * an error code. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy2D, + * ::cudaMemcpy2DToArray, + * ::cudaMemcpy2DFromArray + */ +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy); + +/** + * \brief Copies memory for 2D arrays + * + * Perform a 2D memory copy according to the parameters specified in \p pCopy. 
+ * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice and ::srcPitch are + * ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are + * ignored. + * + * - ::srcXInBytes and ::srcY specify the base address of the source data for + * the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - ::dstXInBytes and ::dstY specify the base address of the destination data + * for the copy. 
+ * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes and ::Height specify the width (in bytes) and height of + * the 2D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * + * \par + * ::cuMemcpy2D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back + * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies + * (device to device, CUDA array to device, CUDA array to CUDA array), + * ::cuMemcpy2D() may fail for pitches not computed by ::cuMemAllocPitch(). + * ::cuMemcpy2DUnaligned() does not have this restriction, but may run + * significantly slower in the cases where ::cuMemcpy2D() would have returned + * an error code. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy2D, + * ::cudaMemcpy2DToArray, + * ::cudaMemcpy2DFromArray + */ +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy); + +/** + * \brief Copies memory for 3D arrays + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. 
The ::CUDA_MEMCPY3D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY3D_st { + + unsigned int srcXInBytes, srcY, srcZ; + unsigned int srcLOD; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; // ignored when src is array + unsigned int srcHeight; // ignored when src is array; may be 0 + if Depth==1 + + unsigned int dstXInBytes, dstY, dstZ; + unsigned int dstLOD; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; // ignored when dst is array + unsigned int dstHeight; // ignored when dst is array; may be 0 + if Depth==1 + + unsigned int WidthInBytes; + unsigned int Height; + unsigned int Depth; + } CUDA_MEMCPY3D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost, ::srcPitch and + * ::srcHeight specify the (host) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice, ::srcPitch and + * ::srcHeight specify the (device) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice, ::srcPitch and + * ::srcHeight are ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data, the bytes per row, + * and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data, the bytes per + * row, and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice, ::dstPitch and + * ::dstHeight are ignored. + * + * - ::srcXInBytes, ::srcY and ::srcZ specify the base address of the source + * data for the copy. 
+ * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+(srcZ*srcHeight+srcY)*srcPitch + + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+(srcZ*srcHeight+srcY)*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - dstXInBytes, ::dstY and ::dstZ specify the base address of the + * destination data for the copy. + * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+(dstZ*dstHeight+dstY)*dstPitch + + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+(dstZ*dstHeight+dstY)*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes, ::Height and ::Depth specify the width (in bytes), height + * and depth of the 3D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcHeight must be greater than or equal to ::Height + + * ::srcY, and ::dstHeight must be greater than or equal to ::Height + ::dstY. + * + * \par + * ::cuMemcpy3D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). + * + * The ::srcLOD and ::dstLOD members of the ::CUDA_MEMCPY3D structure must be + * set to 0. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy3D + */ +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy); + +/** + * \brief Copies memory between contexts + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. See the definition of the ::CUDA_MEMCPY3D_PEER structure + * for documentation of its parameters. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpyDtoDAsync, ::cuMemcpyPeerAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpy3DPeer + */ +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy); + +/** + * \brief Copies memory asynchronously + * + * Copies data between two pointers. 
+ * \p dst and \p src are base pointers of the destination and source, + * respectively. \p ByteCount specifies the number of bytes to copy. Note that + * this function infers the type of the transfer (host to host, host to device, + * device to device, or device to host) from the pointer values. This function + * is only allowed in contexts which support unified addressing. + * + * \param dst - Destination unified virtual address space pointer + * \param src - Source unified virtual address space pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, + size_t ByteCount, CUstream hStream); + +/** + * \brief Copies device memory between two contexts asynchronously. + * + * Copies from device memory in one context to device memory in another + * context. \p dstDevice is the base device pointer of the destination memory + * and \p dstContext is the destination context. \p srcDevice is the base + * device pointer of the source memory and \p srcContext is the source pointer. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstDevice - Destination device pointer + * \param dstContext - Destination context + * \param srcDevice - Source device pointer + * \param srcContext - Source context + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpy3DPeer, ::cuMemcpyDtoDAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpyPeerAsync + */ +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Host to Device + * + * Copies from host memory to device memory. \p dstDevice and \p srcHost are + * the base addresses of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. 
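+ *
+ * \par
+ * A minimal sketch; for the copy to overlap with other work the host buffer
+ * should be page-locked, so a buffer from ::cuMemAllocHost() is used here.
+ * Names and sizes are illustrative and error checking is omitted:
+ * \code
+   CUstream    stream;
+   CUdeviceptr devBuf;
+   void       *pinned;
+   size_t      bytes = 1 << 20;        // illustrative size
+   cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
+   cuMemAllocHost(&pinned, bytes);     // page-locked host staging buffer
+   cuMemAlloc(&devBuf, bytes);
+   // ... fill the pinned buffer on the host ...
+   cuMemcpyHtoDAsync(devBuf, pinned, bytes, stream);
+   // ... enqueue kernels on the same stream that consume devBuf ...
+   cuStreamSynchronize(stream);        // wait before reusing the pinned buffer
+ * \endcode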
+ * + * \param dstDevice - Destination device pointer + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync + */ +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Device to Host + * + * Copies from device to host memory. \p dstHost and \p srcDevice specify the + * base pointers of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstHost - Destination host pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Device to Device + * + * Copies from device memory to device memory. \p dstDevice and \p srcDevice + * are the base pointers of the destination and source, respectively. + * \p ByteCount specifies the number of bytes to copy. 
+ * + * \param dstDevice - Destination device pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Host to Array + * + * Copies from host memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting offset in bytes of the + * destination data. \p srcHost specifies the base address of the source. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyToArrayAsync + */ +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount, + CUstream hStream); + +/** + * \brief Copies memory from Array to Host + * + * Copies from one 1D CUDA array to host memory. \p dstHost specifies the base + * pointer of the destination. 
\p srcArray and \p srcOffset specify the CUDA + * array handle and starting offset in bytes of the source data. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstHost - Destination pointer + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyFromArrayAsync + */ +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount, + CUstream hStream); + +/** + * \brief Copies memory for 2D arrays + * + * Perform a 2D memory copy according to the parameters specified in \p pCopy. + * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. 
::srcHost, ::srcDevice and ::srcPitch are + * ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are + * ignored. + * + * - ::srcXInBytes and ::srcY specify the base address of the source data for + * the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - ::dstXInBytes and ::dstY specify the base address of the destination data + * for the copy. + * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes and ::Height specify the width (in bytes) and height of + * the 2D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcHeight must be greater than or equal to ::Height + + * ::srcY, and ::dstHeight must be greater than or equal to ::Height + ::dstY. + * + * \par + * ::cuMemcpy2DAsync() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back + * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies + * (device to device, CUDA array to device, CUDA array to CUDA array), + * ::cuMemcpy2DAsync() may fail for pitches not computed by ::cuMemAllocPitch(). 
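+ *
+ * \par
+ * A minimal sketch of a host-to-device 2D copy into a pitched allocation.
+ * Here \c hostImage, \c devImage, \c devPitch, \c widthInBytes, \c height and
+ * \c stream are assumed to exist (\c devImage and \c devPitch coming from
+ * ::cuMemAllocPitch()); unused descriptor fields are left zeroed:
+ * \code
+   CUDA_MEMCPY2D cpy = {0};
+   cpy.srcMemoryType = CU_MEMORYTYPE_HOST;
+   cpy.srcHost       = hostImage;       // densely packed host rows
+   cpy.srcPitch      = widthInBytes;
+   cpy.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+   cpy.dstDevice     = devImage;
+   cpy.dstPitch      = devPitch;        // pitch returned by cuMemAllocPitch()
+   cpy.WidthInBytes  = widthInBytes;
+   cpy.Height        = height;
+   cuMemcpy2DAsync(&cpy, stream);
+ * \endcode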
+ * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpy2DAsync, + * ::cudaMemcpy2DToArrayAsync, + * ::cudaMemcpy2DFromArrayAsync + */ +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream); + +/** + * \brief Copies memory for 3D arrays + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. The ::CUDA_MEMCPY3D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY3D_st { + + unsigned int srcXInBytes, srcY, srcZ; + unsigned int srcLOD; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; // ignored when src is array + unsigned int srcHeight; // ignored when src is array; may be 0 + if Depth==1 + + unsigned int dstXInBytes, dstY, dstZ; + unsigned int dstLOD; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; // ignored when dst is array + unsigned int dstHeight; // ignored when dst is array; may be 0 + if Depth==1 + + unsigned int WidthInBytes; + unsigned int Height; + unsigned int Depth; + } CUDA_MEMCPY3D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost, ::srcPitch and + * ::srcHeight specify the (host) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice, ::srcPitch and + * ::srcHeight specify the (device) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. 
+ * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice, ::srcPitch and + * ::srcHeight are ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data, the bytes per row, + * and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data, the bytes per + * row, and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice, ::dstPitch and + * ::dstHeight are ignored. + * + * - ::srcXInBytes, ::srcY and ::srcZ specify the base address of the source + * data for the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+(srcZ*srcHeight+srcY)*srcPitch + + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+(srcZ*srcHeight+srcY)*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - dstXInBytes, ::dstY and ::dstZ specify the base address of the + * destination data for the copy. + * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+(dstZ*dstHeight+dstY)*dstPitch + + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+(dstZ*dstHeight+dstY)*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes, ::Height and ::Depth specify the width (in bytes), height + * and depth of the 3D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcHeight must be greater than or equal to ::Height + + * ::srcY, and ::dstHeight must be greater than or equal to ::Height + ::dstY. + * + * \par + * ::cuMemcpy3DAsync() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). + * + * The ::srcLOD and ::dstLOD members of the ::CUDA_MEMCPY3D structure must be + * set to 0. 
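+ *
+ * \par
+ * A minimal sketch of a device-to-device 3D copy between two linear
+ * allocations; \c srcVolume, \c dstVolume, \c rowBytes, \c height, \c depth
+ * and \c stream are assumed to exist, and zero-initializing the descriptor
+ * keeps ::srcLOD and ::dstLOD at the required value of 0:
+ * \code
+   CUDA_MEMCPY3D cpy = {0};
+   cpy.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+   cpy.srcDevice     = srcVolume;
+   cpy.srcPitch      = rowBytes;
+   cpy.srcHeight     = height;
+   cpy.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+   cpy.dstDevice     = dstVolume;
+   cpy.dstPitch      = rowBytes;
+   cpy.dstHeight     = height;
+   cpy.WidthInBytes  = rowBytes;
+   cpy.Height        = height;
+   cpy.Depth         = depth;
+   cuMemcpy3DAsync(&cpy, stream);
+ * \endcode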
+ * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpy3DAsync + */ +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream); + +/** + * \brief Copies memory between contexts asynchronously. + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. See the definition of the ::CUDA_MEMCPY3D_PEER structure + * for documentation of its parameters. + * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpyDtoDAsync, ::cuMemcpyPeerAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpy3DPeerAsync + */ +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 8-bit values to the specified value + * \p uc. 
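+ *
+ * \par
+ * A minimal sketch that zero-fills a fresh allocation (size illustrative;
+ * error checking omitted):
+ * \code
+   CUdeviceptr dptr;
+   size_t      bytes = 1024;
+   cuMemAlloc(&dptr, bytes);
+   cuMemsetD8(dptr, 0, bytes);          // one byte value replicated N times
+ * \endcode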
+ * + * \param dstDevice - Destination device pointer + * \param uc - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 16-bit values to the specified value + * \p us. The \p dstDevice pointer must be two byte aligned. + * + * \param dstDevice - Destination device pointer + * \param us - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, + size_t N); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 32-bit values to the specified value + * \p ui. The \p dstDevice pointer must be four byte aligned. 
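+ *
+ * \par
+ * A minimal sketch; because the value is an arbitrary 32-bit pattern, the
+ * same call can also fill a float buffer with a constant (here the bit
+ * pattern of 1.0f). Names and sizes are illustrative:
+ * \code
+   CUdeviceptr dptr;
+   size_t      numWords = 256;          // number of 32-bit elements
+   cuMemAlloc(&dptr, numWords * sizeof(unsigned int));
+   cuMemsetD32(dptr, 0x3F800000u, numWords);   // 0x3F800000 == 1.0f
+ * \endcode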
+ * + * \param dstDevice - Destination device pointer + * \param ui - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 8-bit values to the specified value + * \p uc. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param uc - Value to set \param Width - Width of row \param + * Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 16-bit values to the specified value + * \p us. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be two byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). 
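+ *
+ * A minimal sketch (the width and height are arbitrary example values) that
+ * pairs this function with a pitched allocation:
+ *
+ * \code
+   CUdeviceptr dptr;
+   size_t pitch;
+   size_t width = 640, height = 480;                 // 16-bit elements
+   cuMemAllocPitch(&dptr, &pitch, width * sizeof(unsigned short), height, 4);
+   cuMemsetD2D16(dptr, pitch, 0, width, height);     // clear every row
+ * \endcode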
+ * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param us - Value to set \param Width - Width of row \param + * Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, size_t Height); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 32-bit values to the specified value + * \p ui. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be four byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param ui - Value to set \param Width - Width of row \param + * Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 8-bit values to the specified value + * \p uc. 
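+ *
+ * A sketch of clearing a buffer in stream order so the memset can overlap
+ * with other work (assuming \p dptr and \p bytes describe an existing device
+ * allocation; the stream flag is an arbitrary choice):
+ *
+ * \code
+   CUstream stream;
+   cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
+   cuMemsetD8Async(dptr, 0, bytes, stream);   // enqueued; returns immediately
+   // ... launch work into the same stream that consumes the cleared buffer ...
+   cuStreamSynchronize(stream);
+ * \endcode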
+ * + * \param dstDevice - Destination device pointer + * \param uc - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 16-bit values to the specified value + * \p us. The \p dstDevice pointer must be two byte aligned. + * + * \param dstDevice - Destination device pointer + * \param us - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 32-bit values to the specified value + * \p ui. The \p dstDevice pointer must be four byte aligned. 
+ * + * \param dstDevice - Destination device pointer + * \param ui - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 8-bit values to the specified value + * \p uc. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param uc - Value to set \param Width - Width of row \param + * Height - Number of rows \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 16-bit values to the specified value + * \p us. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be two byte aligned. 
This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param us - Value to set \param Width - Width of row \param + * Height - Number of rows \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 32-bit values to the specified value + * \p ui. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be four byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). 
+ * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is + * 1) \param ui - Value to set \param Width - Width of row \param + * Height - Number of rows \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream); + +/** + * \brief Creates a 1D or 2D CUDA array + * + * Creates a CUDA array according to the ::CUDA_ARRAY_DESCRIPTOR structure + * \p pAllocateArray and returns a handle to the new CUDA array in \p *pHandle. + * The ::CUDA_ARRAY_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + CUarray_format Format; + unsigned int NumChannels; + } CUDA_ARRAY_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, and \p Height are the width, and height of the CUDA array (in + * elements); the CUDA array is one-dimensional if height is 0, two-dimensional + * otherwise; + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * Here are examples of CUDA array descriptions: + * + * Description for a CUDA array of 2048 floats: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 2048; + desc.Height = 1; + * \endcode + * + * Description for a 64 x 64 CUDA array of floats: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 64; + desc.Height = 64; + * \endcode + * + * Description for a \p width x \p height CUDA array of 64-bit, 4x16-bit + * float16's: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_HALF; + desc.NumChannels = 4; + desc.Width = width; + desc.Height = height; + * \endcode + * + * Description for a \p width x \p height CUDA array of 16-bit elements, each + * of which is two 8-bit unsigned chars: + * \code + 
CUDA_ARRAY_DESCRIPTOR arrayDesc; + desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; + desc.NumChannels = 2; + desc.Width = width; + desc.Height = height; + * \endcode + * + * \param pHandle - Returned array + * \param pAllocateArray - Array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocArray + */ +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR *pAllocateArray); + +/** + * \brief Get a 1D or 2D CUDA array descriptor + * + * Returns in \p *pArrayDescriptor a descriptor containing information on the + * format and dimensions of the CUDA array \p hArray. It is useful for + * subroutines that have been passed a CUDA array, but need to know the CUDA + * array parameters for validation or other purposes. + * + * \param pArrayDescriptor - Returned array descriptor + * \param hArray - Array to get descriptor of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaArrayGetInfo + */ +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, + CUarray hArray); + +/** + * \brief Returns the layout properties of a sparse CUDA array + * + * Returns the layout properties of a sparse CUDA array in \p sparseProperties + * If the CUDA array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * If the returned value in ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains + * ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, then + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize represents the total size of the + * array. Otherwise, it will be zero. Also, the returned value in + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is always zero. 
Note that + * the \p array must have been allocated using ::cuArrayCreate or + * ::cuArray3DCreate. For CUDA arrays obtained using ::cuMipmappedArrayGetLevel, + * ::CUDA_ERROR_INVALID_VALUE will be returned. Instead, + * ::cuMipmappedArrayGetSparseProperties must be used to obtain the sparse + * properties of the entire CUDA mipmapped array to which \p array belongs to. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] array - CUDA array to get the sparse properties of + * \sa ::cuMipmappedArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuArrayGetSparseProperties( + CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUarray array); + +/** + * \brief Returns the layout properties of a sparse CUDA mipmapped array + * + * Returns the sparse array layout properties in \p sparseProperties + * If the CUDA mipmapped array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * For non-layered CUDA mipmapped arrays, + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize returns the size of the mip tail + * region. The mip tail region includes all mip levels whose width, height or + * depth is less than that of the tile. For layered CUDA mipmapped arrays, if + * ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains + * ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, then + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies the size of the mip + * tail of all layers combined. Otherwise, + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies mip tail size per + * layer. The returned value of + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is valid only if + * ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize is non-zero. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] mipmap - CUDA mipmapped array to get the sparse properties of + * \sa ::cuArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuMipmappedArrayGetSparseProperties( + CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUmipmappedArray mipmap); + +/** + * \brief Returns the memory requirements of a CUDA array + * + * Returns the memory requirements of a CUDA array in \p memoryRequirements + * If the CUDA array is not allocated with flag ::CUDA_ARRAY3D_DEFERRED_MAPPING + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::size + * represents the total size of the CUDA array. + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::alignment + * represents the alignment necessary for mapping the CUDA array. 
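+ *
+ * A short sketch (assuming \p arr was created with
+ * ::CUDA_ARRAY3D_DEFERRED_MAPPING and \p dev is the device it belongs to):
+ *
+ * \code
+   CUDA_ARRAY_MEMORY_REQUIREMENTS req;
+   cuArrayGetMemoryRequirements(&req, arr, dev);
+   // req.size and req.alignment can now be used when creating and mapping
+   // the physical backing (see ::cuMemCreate and ::cuMemMapArrayAsync).
+ * \endcode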
+ * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] memoryRequirements - Pointer to ::CUDA_ARRAY_MEMORY_REQUIREMENTS + * \param[in] array - CUDA array to get the memory requirements of + * \param[in] device - Device to get the memory requirements for + * \sa ::cuMipmappedArrayGetMemoryRequirements, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI +cuArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIREMENTS *memoryRequirements, + CUarray array, CUdevice device); + +/** + * \brief Returns the memory requirements of a CUDA mipmapped array + * + * Returns the memory requirements of a CUDA mipmapped array in \p + * memoryRequirements If the CUDA mipmapped array is not allocated with flag + * ::CUDA_ARRAY3D_DEFERRED_MAPPING + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::size + * represents the total size of the CUDA mipmapped array. + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::alignment + * represents the alignment necessary for mapping the CUDA mipmapped + * array. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] memoryRequirements - Pointer to ::CUDA_ARRAY_MEMORY_REQUIREMENTS + * \param[in] mipmap - CUDA mipmapped array to get the memory requirements of + * \param[in] device - Device to get the memory requirements for + * \sa ::cuArrayGetMemoryRequirements, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuMipmappedArrayGetMemoryRequirements( + CUDA_ARRAY_MEMORY_REQUIREMENTS *memoryRequirements, CUmipmappedArray mipmap, + CUdevice device); + +/** + * \brief Gets a CUDA array plane from a CUDA array + * + * Returns in \p pPlaneArray a CUDA array that represents a single format plane + * of the CUDA array \p hArray. + * + * If \p planeIdx is greater than the maximum number of planes in this array or + * if the array does not have a multi-planar format e.g: ::CU_AD_FORMAT_NV12, + * then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Note that if the \p hArray has format ::CU_AD_FORMAT_NV12, then passing in 0 + * for \p planeIdx returns a CUDA array of the same size as \p hArray but with + * one channel and ::CU_AD_FORMAT_UNSIGNED_INT8 as its format. If 1 is passed + * for \p planeIdx, then the returned CUDA array has half the height and width + * of \p hArray with two channels and ::CU_AD_FORMAT_UNSIGNED_INT8 as its + * format. + * + * \param pPlaneArray - Returned CUDA array referenced by the \p planeIdx + * \param hArray - Multiplanar CUDA array + * \param planeIdx - Plane index + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::cuArrayCreate, + * ::cudaArrayGetPlane + */ +CUresult CUDAAPI cuArrayGetPlane(CUarray *pPlaneArray, CUarray hArray, + unsigned int planeIdx); + +/** + * \brief Destroys a CUDA array + * + * Destroys the CUDA array \p hArray. 
+ * + * \param hArray - Array to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ARRAY_IS_MAPPED, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFreeArray + */ +CUresult CUDAAPI cuArrayDestroy(CUarray hArray); + +/** + * \brief Creates a 3D CUDA array + * + * Creates a CUDA array according to the ::CUDA_ARRAY3D_DESCRIPTOR structure + * \p pAllocateArray and returns a handle to the new CUDA array in \p *pHandle. + * The ::CUDA_ARRAY3D_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + unsigned int Depth; + CUarray_format Format; + unsigned int NumChannels; + unsigned int Flags; + } CUDA_ARRAY3D_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, \p Height, and \p Depth are the width, height, and depth of the + * CUDA array (in elements); the following types of CUDA arrays can be + allocated: + * - A 1D array is allocated if \p Height and \p Depth extents are both + zero. + * - A 2D array is allocated if only \p Depth extent is zero. + * - A 3D array is allocated if all three extents are non-zero. + * - A 1D layered CUDA array is allocated if only \p Height is zero and the + * ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 1D array. The + number + * of layers is determined by the depth extent. + * - A 2D layered CUDA array is allocated if all three extents are non-zero + and + * the ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 2D array. The + number + * of layers is determined by the depth extent. + * - A cubemap CUDA array is allocated if all three extents are non-zero and + the + * ::CUDA_ARRAY3D_CUBEMAP flag is set. \p Width must be equal to \p + Height, and + * \p Depth must be six. A cubemap is a special type of 2D layered CUDA + array, + * where the six layers represent the six faces of a cube. The order of + the six + * layers in memory is the same as that listed in ::CUarray_cubemap_face. + * - A cubemap layered CUDA array is allocated if all three extents are + non-zero, + * and both, ::CUDA_ARRAY3D_CUBEMAP and ::CUDA_ARRAY3D_LAYERED flags are + set. + * \p Width must be equal to \p Height, and \p Depth must be a multiple of + six. + * A cubemap layered CUDA array is a special type of 2D layered CUDA array + that + * consists of a collection of cubemaps. The first six layers represent + the first + * cubemap, the next six layers form the second cubemap, and so on. 
+ * + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * - ::Flags may be set to + * - ::CUDA_ARRAY3D_LAYERED to enable creation of layered CUDA arrays. If this + flag is set, + * \p Depth specifies the number of layers, not the depth of a 3D array. + * - ::CUDA_ARRAY3D_SURFACE_LDST to enable surface references to be bound to + the CUDA array. + * If this flag is not set, ::cuSurfRefSetArray will fail when attempting to + bind the CUDA array + * to a surface reference. + * - ::CUDA_ARRAY3D_CUBEMAP to enable creation of cubemaps. If this flag is + set, \p Width must be + * equal to \p Height, and \p Depth must be six. If the + ::CUDA_ARRAY3D_LAYERED flag is also set, + * then \p Depth must be a multiple of six. + * - ::CUDA_ARRAY3D_TEXTURE_GATHER to indicate that the CUDA array will be + used for texture gather. + * Texture gather can only be performed on 2D CUDA arrays. + * + * \p Width, \p Height and \p Depth must meet certain size requirements as + listed in the following table. + * All values are specified in elements. Note that for brevity's sake, the full + name of the device attribute + * is not specified. For ex., TEXTURE1D_WIDTH refers to the device attribute + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH. + * + * Note that 2D CUDA arrays have different size requirements if the + ::CUDA_ARRAY3D_TEXTURE_GATHER flag + * is set. \p Width and \p Height must not be greater than + ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH + * and ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT respectively, in + that case. + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ * <table>
+ * <tr><td><b>CUDA array type</b></td>
+ * <td><b>Valid extents that must always be met<br>
+ * {(width range in elements), (height range), (depth range)}</b></td>
+ * <td><b>Valid extents with CUDA_ARRAY3D_SURFACE_LDST set<br>
+ * {(width range in elements), (height range), (depth range)}</b></td></tr>
+ * <tr><td>1D</td>
+ * <td>{ (1,TEXTURE1D_WIDTH), 0, 0 }</td>
+ * <td>{ (1,SURFACE1D_WIDTH), 0, 0 }</td></tr>
+ * <tr><td>2D</td>
+ * <td>{ (1,TEXTURE2D_WIDTH), (1,TEXTURE2D_HEIGHT), 0 }</td>
+ * <td>{ (1,SURFACE2D_WIDTH), (1,SURFACE2D_HEIGHT), 0 }</td></tr>
+ * <tr><td>3D</td>
+ * <td>{ (1,TEXTURE3D_WIDTH), (1,TEXTURE3D_HEIGHT), (1,TEXTURE3D_DEPTH) }
+ * <br>OR<br>
+ * { (1,TEXTURE3D_WIDTH_ALTERNATE), (1,TEXTURE3D_HEIGHT_ALTERNATE),
+ * (1,TEXTURE3D_DEPTH_ALTERNATE) }</td>
+ * <td>{ (1,SURFACE3D_WIDTH), (1,SURFACE3D_HEIGHT), (1,SURFACE3D_DEPTH) }</td></tr>
+ * <tr><td>1D Layered</td>
+ * <td>{ (1,TEXTURE1D_LAYERED_WIDTH), 0, (1,TEXTURE1D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE1D_LAYERED_WIDTH), 0, (1,SURFACE1D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>2D Layered</td>
+ * <td>{ (1,TEXTURE2D_LAYERED_WIDTH), (1,TEXTURE2D_LAYERED_HEIGHT),
+ * (1,TEXTURE2D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE2D_LAYERED_WIDTH), (1,SURFACE2D_LAYERED_HEIGHT),
+ * (1,SURFACE2D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>Cubemap</td>
+ * <td>{ (1,TEXTURECUBEMAP_WIDTH), (1,TEXTURECUBEMAP_WIDTH), 6 }</td>
+ * <td>{ (1,SURFACECUBEMAP_WIDTH), (1,SURFACECUBEMAP_WIDTH), 6 }</td></tr>
+ * <tr><td>Cubemap Layered</td>
+ * <td>{ (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_WIDTH),
+ * (1,TEXTURECUBEMAP_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_WIDTH),
+ * (1,SURFACECUBEMAP_LAYERED_LAYERS) }</td></tr>
+ * </table>
+ * + * Here are examples of CUDA array descriptions: + * + * Description for a CUDA array of 2048 floats: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 2048; + desc.Height = 0; + desc.Depth = 0; + * \endcode + * + * Description for a 64 x 64 CUDA array of floats: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 64; + desc.Height = 64; + desc.Depth = 0; + * \endcode + * + * Description for a \p width x \p height x \p depth CUDA array of 64-bit, + * 4x16-bit float16's: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_HALF; + desc.NumChannels = 4; + desc.Width = width; + desc.Height = height; + desc.Depth = depth; + * \endcode + * + * \param pHandle - Returned array + * \param pAllocateArray - 3D array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMalloc3DArray + */ +CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray); + +/** + * \brief Get a 3D CUDA array descriptor + * + * Returns in \p *pArrayDescriptor a descriptor containing information on the + * format and dimensions of the CUDA array \p hArray. It is useful for + * subroutines that have been passed a CUDA array, but need to know the CUDA + * array parameters for validation or other purposes. + * + * This function may be called on 1D and 2D arrays, in which case the \p Height + * and/or \p Depth members of the descriptor struct will be set to 0. 
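+ *
+ * A short sketch (assuming \p hArray is a valid CUDA array handle):
+ *
+ * \code
+   CUDA_ARRAY3D_DESCRIPTOR desc;
+   cuArray3DGetDescriptor(&desc, hArray);
+   // desc.Width, desc.Height, desc.Depth, desc.Format and desc.NumChannels
+   // now describe hArray; Height/Depth are 0 for lower-dimensional arrays.
+ * \endcode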
+ * + * \param pArrayDescriptor - Returned 3D array descriptor + * \param hArray - 3D array to get descriptor of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaArrayGetInfo + */ +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray); + +/** + * \brief Creates a CUDA mipmapped array + * + * Creates a CUDA mipmapped array according to the ::CUDA_ARRAY3D_DESCRIPTOR + structure + * \p pMipmappedArrayDesc and returns a handle to the new CUDA mipmapped array + in \p *pHandle. + * \p numMipmapLevels specifies the number of mipmap levels to be allocated. + This value is + * clamped to the range [1, 1 + floor(log2(max(width, height, depth)))]. + * + * The ::CUDA_ARRAY3D_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + unsigned int Depth; + CUarray_format Format; + unsigned int NumChannels; + unsigned int Flags; + } CUDA_ARRAY3D_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, \p Height, and \p Depth are the width, height, and depth of the + * CUDA array (in elements); the following types of CUDA arrays can be + allocated: + * - A 1D mipmapped array is allocated if \p Height and \p Depth extents are + both zero. + * - A 2D mipmapped array is allocated if only \p Depth extent is zero. + * - A 3D mipmapped array is allocated if all three extents are non-zero. + * - A 1D layered CUDA mipmapped array is allocated if only \p Height is + zero and the + * ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 1D array. The + number + * of layers is determined by the depth extent. + * - A 2D layered CUDA mipmapped array is allocated if all three extents are + non-zero and + * the ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 2D array. The + number + * of layers is determined by the depth extent. + * - A cubemap CUDA mipmapped array is allocated if all three extents are + non-zero and the + * ::CUDA_ARRAY3D_CUBEMAP flag is set. \p Width must be equal to \p + Height, and + * \p Depth must be six. A cubemap is a special type of 2D layered CUDA + array, + * where the six layers represent the six faces of a cube. The order of + the six + * layers in memory is the same as that listed in ::CUarray_cubemap_face. + * - A cubemap layered CUDA mipmapped array is allocated if all three + extents are non-zero, + * and both, ::CUDA_ARRAY3D_CUBEMAP and ::CUDA_ARRAY3D_LAYERED flags are + set. + * \p Width must be equal to \p Height, and \p Depth must be a multiple of + six. + * A cubemap layered CUDA array is a special type of 2D layered CUDA array + that + * consists of a collection of cubemaps. 
The first six layers represent + the first + * cubemap, the next six layers form the second cubemap, and so on. + * + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * - ::Flags may be set to + * - ::CUDA_ARRAY3D_LAYERED to enable creation of layered CUDA mipmapped + arrays. If this flag is set, + * \p Depth specifies the number of layers, not the depth of a 3D array. + * - ::CUDA_ARRAY3D_SURFACE_LDST to enable surface references to be bound to + individual mipmap levels of + * the CUDA mipmapped array. If this flag is not set, ::cuSurfRefSetArray + will fail when attempting to + * bind a mipmap level of the CUDA mipmapped array to a surface reference. + * - ::CUDA_ARRAY3D_CUBEMAP to enable creation of mipmapped cubemaps. If this + flag is set, \p Width must be + * equal to \p Height, and \p Depth must be six. If the + ::CUDA_ARRAY3D_LAYERED flag is also set, + * then \p Depth must be a multiple of six. + * - ::CUDA_ARRAY3D_TEXTURE_GATHER to indicate that the CUDA mipmapped array + will be used for texture gather. + * Texture gather can only be performed on 2D CUDA mipmapped arrays. + * + * \p Width, \p Height and \p Depth must meet certain size requirements as + listed in the following table. + * All values are specified in elements. Note that for brevity's sake, the full + name of the device attribute + * is not specified. For ex., TEXTURE1D_MIPMAPPED_WIDTH refers to the device + attribute + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH. + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ * <table>
+ * <tr><td><b>CUDA array type</b></td>
+ * <td><b>Valid extents that must always be met<br>
+ * {(width range in elements), (height range), (depth range)}</b></td>
+ * <td><b>Valid extents with CUDA_ARRAY3D_SURFACE_LDST set<br>
+ * {(width range in elements), (height range), (depth range)}</b></td></tr>
+ * <tr><td>1D</td>
+ * <td>{ (1,TEXTURE1D_MIPMAPPED_WIDTH), 0, 0 }</td>
+ * <td>{ (1,SURFACE1D_WIDTH), 0, 0 }</td></tr>
+ * <tr><td>2D</td>
+ * <td>{ (1,TEXTURE2D_MIPMAPPED_WIDTH), (1,TEXTURE2D_MIPMAPPED_HEIGHT), 0 }</td>
+ * <td>{ (1,SURFACE2D_WIDTH), (1,SURFACE2D_HEIGHT), 0 }</td></tr>
+ * <tr><td>3D</td>
+ * <td>{ (1,TEXTURE3D_WIDTH), (1,TEXTURE3D_HEIGHT), (1,TEXTURE3D_DEPTH) }
+ * <br>OR<br>
+ * { (1,TEXTURE3D_WIDTH_ALTERNATE), (1,TEXTURE3D_HEIGHT_ALTERNATE),
+ * (1,TEXTURE3D_DEPTH_ALTERNATE) }</td>
+ * <td>{ (1,SURFACE3D_WIDTH), (1,SURFACE3D_HEIGHT), (1,SURFACE3D_DEPTH) }</td></tr>
+ * <tr><td>1D Layered</td>
+ * <td>{ (1,TEXTURE1D_LAYERED_WIDTH), 0, (1,TEXTURE1D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE1D_LAYERED_WIDTH), 0, (1,SURFACE1D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>2D Layered</td>
+ * <td>{ (1,TEXTURE2D_LAYERED_WIDTH), (1,TEXTURE2D_LAYERED_HEIGHT),
+ * (1,TEXTURE2D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE2D_LAYERED_WIDTH), (1,SURFACE2D_LAYERED_HEIGHT),
+ * (1,SURFACE2D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>Cubemap</td>
+ * <td>{ (1,TEXTURECUBEMAP_WIDTH), (1,TEXTURECUBEMAP_WIDTH), 6 }</td>
+ * <td>{ (1,SURFACECUBEMAP_WIDTH), (1,SURFACECUBEMAP_WIDTH), 6 }</td></tr>
+ * <tr><td>Cubemap Layered</td>
+ * <td>{ (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_WIDTH),
+ * (1,TEXTURECUBEMAP_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_WIDTH),
+ * (1,SURFACECUBEMAP_LAYERED_LAYERS) }</td></tr>
+ * </table>
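+ *
+ * A minimal sketch of creating a mipmapped 2D array of floats and fetching
+ * its finest level (the 1024 x 1024 extent and the level count of 11 are
+ * example values):
+ *
+ * \code
+   CUDA_ARRAY3D_DESCRIPTOR desc;
+   desc.Width = 1024;
+   desc.Height = 1024;
+   desc.Depth = 0;
+   desc.Format = CU_AD_FORMAT_FLOAT;
+   desc.NumChannels = 1;
+   desc.Flags = 0;
+   CUmipmappedArray mipmap;
+   cuMipmappedArrayCreate(&mipmap, &desc, 11);    // 11 = 1 + floor(log2(1024))
+   CUarray level0;
+   cuMipmappedArrayGetLevel(&level0, mipmap, 0);  // level 0 is the full-size level
+ * \endcode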
+ * + * + * \param pHandle - Returned mipmapped array + * \param pMipmappedArrayDesc - mipmapped array descriptor + * \param numMipmapLevels - Number of mipmap levels + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuMipmappedArrayDestroy, + * ::cuMipmappedArrayGetLevel, + * ::cuArrayCreate, + * ::cudaMallocMipmappedArray + */ +CUresult CUDAAPI +cuMipmappedArrayCreate(CUmipmappedArray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels); + +/** + * \brief Gets a mipmap level of a CUDA mipmapped array + * + * Returns in \p *pLevelArray a CUDA array that represents a single mipmap level + * of the CUDA mipmapped array \p hMipmappedArray. + * + * If \p level is greater than the maximum number of levels in this mipmapped + * array, + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \param pLevelArray - Returned mipmap level CUDA array + * \param hMipmappedArray - CUDA mipmapped array + * \param level - Mipmap level + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::cuMipmappedArrayCreate, + * ::cuMipmappedArrayDestroy, + * ::cuArrayCreate, + * ::cudaGetMipmappedArrayLevel + */ +CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, + CUmipmappedArray hMipmappedArray, + unsigned int level); + +/** + * \brief Destroys a CUDA mipmapped array + * + * Destroys the CUDA mipmapped array \p hMipmappedArray. + * + * \param hMipmappedArray - Mipmapped array to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ARRAY_IS_MAPPED, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa + * ::cuMipmappedArrayCreate, + * ::cuMipmappedArrayGetLevel, + * ::cuArrayCreate, + * ::cudaFreeMipmappedArray + */ +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); + +/** + * \brief Retrieve handle for an address range + * + * Get a handle of the specified type to an address range. The address range + * must have been obtained by a prior call to either ::cuMemAlloc or + * ::cuMemAddressReserve. If the address range was obtained via + * ::cuMemAddressReserve, it must also be fully mapped via ::cuMemMap. The + * address range must have been obtained by a prior call to either + * ::cuMemAllocHost or + * ::cuMemHostAlloc on Tegra. + * + * Users must ensure the \p dptr and \p size are aligned to the host page size. + * + * When requesting CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, + * users are expected to query for dma_buf support for the platform + * by using ::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED device attribute before + * calling this API. The \p handle will be interpreted as a pointer to an + * integer to store the dma_buf file descriptor. Users must ensure the entire + * address range is backed and mapped when the address range is allocated by + * ::cuMemAddressReserve. All the physical allocations backing the address range + * must be resident on the same device and have identical allocation properties. 
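+ *
+ * For example (a sketch; \p dptr and \p size are assumed to be an existing,
+ * fully backed range that satisfies the requirements above), a dma_buf file
+ * descriptor can be requested as follows:
+ *
+ * \code
+   int dmaBufFd = -1;
+   cuMemGetHandleForAddressRange(&dmaBufFd, dptr, size,
+                                 CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0);
+   // dmaBufFd now refers to the backing memory and can be handed to APIs
+   // that consume dma_buf file descriptors; close() it when no longer needed.
+ * \endcode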
+ * Users are also expected to retrieve a new handle every time the underlying + * physical allocation(s) corresponding to a previously queried VA range are + * changed. + * + * \param[out] handle - Pointer to the location where the returned handle + * will be stored. \param[in] dptr - Pointer to a valid CUDA device + * allocation. Must be aligned to host page size. \param[in] size - + * Length of the address range. Must be aligned to host page size. \param[in] + * handleType - Type of handle requested (defines type and size of the \p + * handle output parameter) \param[in] flags - Reserved, must be zero + * + * \return + * CUDA_SUCCESS + * CUDA_ERROR_INVALID_VALUE + * CUDA_ERROR_NOT_SUPPORTED + */ +CUresult CUDAAPI cuMemGetHandleForAddressRange(void *handle, CUdeviceptr dptr, + size_t size, + CUmemRangeHandleType handleType, + unsigned long long flags); + +/** @} */ /* END CUDA_MEM */ + +/** + * \defgroup CUDA_VA Virtual Memory Management + * + * ___MANBRIEF___ virtual memory management functions of the low-level CUDA + * driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the virtual memory management functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Allocate an address range reservation. + * + * Reserves a virtual address range based on the given parameters, giving + * the starting address of the range in \p ptr. This API requires a system that + * supports UVA. The size and address parameters must be a multiple of the + * host page size and the alignment must be a power of two or zero for default + * alignment. + * + * \param[out] ptr - Resulting pointer to start of virtual address range + * allocated \param[in] size - Size of the reserved virtual address range + * requested \param[in] alignment - Alignment of the reserved virtual address + * range requested \param[in] addr - Fixed starting address range + * requested \param[in] flags - Currently unused, must be zero \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemAddressFree + */ +CUresult CUDAAPI cuMemAddressReserve(CUdeviceptr *ptr, size_t size, + size_t alignment, CUdeviceptr addr, + unsigned long long flags); + +/** + * \brief Free an address range reservation. + * + * Frees a virtual address range reserved by cuMemAddressReserve. The size + * must match what was given to memAddressReserve and the ptr given must + * match what was returned from memAddressReserve. + * + * \param[in] ptr - Starting address of the virtual address range to free + * \param[in] size - Size of the virtual address region to free + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemAddressReserve + */ +CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size); + +/** + * \brief Create a CUDA memory handle representing a memory allocation of a + * given size described by the given properties + * + * This creates a memory allocation on the target device specified through the + * \p prop structure. The created allocation will not have any device or host + * mappings. The generic memory \p handle for the allocation can be + * mapped to the address space of calling process via ::cuMemMap. 
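+ *
+ * A condensed sketch of the typical create / reserve / map / set-access
+ * sequence (the device ordinal and requested size are example values, and
+ * error checking is omitted):
+ *
+ * \code
+   CUmemAllocationProp prop = {0};
+   prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+   prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+   prop.location.id = 0;                              // device ordinal 0
+
+   size_t granularity = 0;
+   cuMemGetAllocationGranularity(&granularity, &prop,
+                                 CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+   size_t size = ((4096 + granularity - 1) / granularity) * granularity;
+
+   CUmemGenericAllocationHandle handle;
+   cuMemCreate(&handle, size, &prop, 0);              // physical allocation
+
+   CUdeviceptr ptr;
+   cuMemAddressReserve(&ptr, size, 0, 0, 0);          // virtual address range
+   cuMemMap(ptr, size, 0, handle, 0);                 // map handle into the range
+
+   CUmemAccessDesc access = {0};
+   access.location = prop.location;
+   access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+   cuMemSetAccess(ptr, size, &access, 1);             // make the mapping accessible
+
+   // ... use ptr ...
+
+   cuMemUnmap(ptr, size);
+   cuMemRelease(handle);
+   cuMemAddressFree(ptr, size);
+ * \endcode
+ *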
This handle + * cannot be transmitted directly to other processes (see + * ::cuMemExportToShareableHandle). On Windows, the caller must also pass + * an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which + * limits or allows access to this handle for a recipient process (see + * ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this + * allocation must be a multiple of the the value given via + * ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM + * flag. + * To create a CPU allocation targeting a specific host NUMA node, applications + * must set ::CUmemAllocationProp::CUmemLocation::type to + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA and + * ::CUmemAllocationProp::CUmemLocation::id must specify the NUMA ID of the CPU. + * On systems where NUMA is not available + * ::CUmemAllocationProp::CUmemLocation::id must be set to 0. + * + * Applications can set ::CUmemAllocationProp::requestedHandleTypes to + * ::CU_MEM_HANDLE_TYPE_FABRIC in order to create allocations suitable for + * sharing within an IMEX domain. An IMEX domain is either an OS instance or a + * group of securely connected OS instances using the NVIDIA IMEX daemon. An + * IMEX channel is a global resource within the IMEX domain that represents a + * logical entity that aims to provide fine grained accessibility control for + * the participating processes. When exporter and importer CUDA processes have + * been granted access to the same IMEX channel, they can securely share memory. + * If the allocating process does not have access setup for an IMEX channel, + * attempting to create a ::CUmemGenericAllocationHandle with + * ::CU_MEM_HANDLE_TYPE_FABRIC will result in ::CUDA_ERROR_NOT_PERMITTED. The + * nvidia-modprobe CLI provides more information regarding setting up of IMEX + * channels. + * + * If ::CUmemAllocationProp::allocFlags::usage contains + * ::CU_MEM_CREATE_USAGE_TILE_POOL flag then the memory allocation is intended + * only to be used as backing tile pool for sparse CUDA arrays and sparse CUDA + * mipmapped arrays. (see ::cuMemMapArrayAsync). + * + * \param[out] handle - Value of handle returned. All operations on this + * allocation are to be performed using this handle. \param[in] size - Size + * of the allocation requested \param[in] prop - Properties of the allocation + * to create. \param[in] flags - flags for future use, must be zero now. + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemRelease, ::cuMemExportToShareableHandle, + * ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags); + +/** + * \brief Release a memory handle representing a memory allocation which was + * previously allocated through cuMemCreate. + * + * Frees the memory that was allocated on a device through cuMemCreate. + * + * The memory allocation will be freed when all outstanding mappings to the + * memory are unmapped and when all outstanding references to the handle + * (including it's shareable counterparts) are also released. The generic memory + * handle can be freed when there are still outstanding mappings made with this + * handle. 
Each time a recipient process imports a shareable handle, it needs to + * pair it with + * ::cuMemRelease for the handle to be freed. If \p handle is not a valid + * handle the behavior is undefined. + * + * \param[in] handle Value of handle which was returned previously by + * cuMemCreate. \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemCreate + */ +CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle); + +/** + * \brief Maps an allocation handle to a reserved virtual address range. + * + * Maps bytes of memory represented by \p handle starting from byte \p offset to + * \p size to address range [\p addr, \p addr + \p size]. This range must be an + * address reservation previously reserved with ::cuMemAddressReserve, and + * \p offset + \p size must be less than the size of the memory allocation. + * Both \p ptr, \p size, and \p offset must be a multiple of the value given via + * ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM + * flag. If \p handle represents a multicast object, \p ptr, \p size and \p + * offset must be aligned to the value returned by ::cuMulticastGetGranularity + * with the flag + * ::CU_MULTICAST_MINIMUM_GRANULARITY. For best performance however, it is + * recommended that \p ptr, \p size and \p offset be aligned to the value + * returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_RECOMMENDED_GRANULARITY. + * + * Please note calling ::cuMemMap does not make the address accessible, + * the caller needs to update accessibility of a contiguous mapped VA + * range by calling ::cuMemSetAccess. + * + * Once a recipient process obtains a shareable memory handle + * from ::cuMemImportFromShareableHandle, the process must + * use ::cuMemMap to map the memory into its address ranges before + * setting accessibility with ::cuMemSetAccess. + * + * ::cuMemMap can only create mappings on VA range reservations + * that are not currently mapped. + * + * \param[in] ptr - Address where memory will be mapped. + * \param[in] size - Size of the memory mapping. + * \param[in] offset - Offset into the memory represented by + * - \p handle from which to start mapping + * - Note: currently must be zero. + * \param[in] handle - Handle to a shareable memory + * \param[in] flags - flags for future use, must be zero now. + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemUnmap, ::cuMemSetAccess, ::cuMemCreate, ::cuMemAddressReserve, + * ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, + CUmemGenericAllocationHandle handle, + unsigned long long flags); + +/** + * \brief Maps or unmaps subregions of sparse CUDA arrays and sparse CUDA + mipmapped arrays + * + * Performs map or unmap operations on subregions of sparse CUDA arrays and + sparse CUDA mipmapped arrays. + * Each operation is specified by a ::CUarrayMapInfo entry in the \p mapInfoList + array of size \p count. 
+ * The structure ::CUarrayMapInfo is defined as follow: + \code + typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; + union { + struct { + unsigned int level; + unsigned int layer; + unsigned int offsetX; + unsigned int offsetY; + unsigned int offsetZ; + unsigned int extentWidth; + unsigned int extentHeight; + unsigned int extentDepth; + } sparseLevel; + struct { + unsigned int layer; + unsigned long long offset; + unsigned long long size; + } miptail; + } subresource; + + CUmemOperationType memOperationType; + + CUmemHandleType memHandleType; + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; + unsigned int deviceBitMask; + unsigned int flags; + unsigned int reserved[2]; + } CUarrayMapInfo; + \endcode + * + * where ::CUarrayMapInfo::resourceType specifies the type of resource to be + operated on. + * If ::CUarrayMapInfo::resourceType is set to + ::CUresourcetype::CU_RESOURCE_TYPE_ARRAY then + * ::CUarrayMapInfo::resource::array must be set to a valid sparse CUDA array + handle. + * The CUDA array must be either a 2D, 2D layered or 3D CUDA array and must have + been allocated using + * ::cuArrayCreate or ::cuArray3DCreate with the flag ::CUDA_ARRAY3D_SPARSE + * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. + * For CUDA arrays obtained using ::cuMipmappedArrayGetLevel, + ::CUDA_ERROR_INVALID_VALUE will be returned. + * If ::CUarrayMapInfo::resourceType is set to + ::CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY + * then ::CUarrayMapInfo::resource::mipmap must be set to a valid sparse CUDA + mipmapped array handle. + * The CUDA mipmapped array must be either a 2D, 2D layered or 3D CUDA mipmapped + array and must have been + * allocated using ::cuMipmappedArrayCreate with the flag ::CUDA_ARRAY3D_SPARSE + * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. + * + * ::CUarrayMapInfo::subresourceType specifies the type of subresource within + the resource. + * ::CUarraySparseSubresourceType_enum is defined as: + \code + typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 + } CUarraySparseSubresourceType; + \endcode + * + * where + ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL + indicates a + * sparse-miplevel which spans at least one tile in every dimension. The + remaining miplevels which + * are too small to span at least one tile in any dimension constitute the mip + tail region as indicated by + * ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL + subresource type. + * + * If ::CUarrayMapInfo::subresourceType is set to + ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL + * then ::CUarrayMapInfo::subresource::sparseLevel struct must contain valid + array subregion offsets and extents. + * The ::CUarrayMapInfo::subresource::sparseLevel::offsetX, + ::CUarrayMapInfo::subresource::sparseLevel::offsetY + * and ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must specify valid X, + Y and Z offsets respectively. + * The ::CUarrayMapInfo::subresource::sparseLevel::extentWidth, + ::CUarrayMapInfo::subresource::sparseLevel::extentHeight + * and ::CUarrayMapInfo::subresource::sparseLevel::extentDepth must specify + valid width, height and depth extents respectively. + * These offsets and extents must be aligned to the corresponding tile + dimension. 
+ * For CUDA mipmapped arrays ::CUarrayMapInfo::subresource::sparseLevel::level + must specify a valid mip level index. Otherwise, + * must be zero. + * For layered CUDA arrays and layered CUDA mipmapped arrays + ::CUarrayMapInfo::subresource::sparseLevel::layer must specify a valid layer + index. Otherwise, + * must be zero. + * ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must be zero and + ::CUarrayMapInfo::subresource::sparseLevel::extentDepth + * must be set to 1 for 2D and 2D layered CUDA arrays and CUDA mipmapped arrays. + * Tile extents can be obtained by calling ::cuArrayGetSparseProperties and + ::cuMipmappedArrayGetSparseProperties + * + * If ::CUarrayMapInfo::subresourceType is set to + ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL + * then ::CUarrayMapInfo::subresource::miptail struct must contain valid mip + tail offset in + * ::CUarrayMapInfo::subresource::miptail::offset and size in + ::CUarrayMapInfo::subresource::miptail::size. + * Both, mip tail offset and mip tail size must be aligned to the tile size. + * For layered CUDA mipmapped arrays which don't have the flag + ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL set in + ::CUDA_ARRAY_SPARSE_PROPERTIES::flags + * as returned by ::cuMipmappedArrayGetSparseProperties, + ::CUarrayMapInfo::subresource::miptail::layer must specify a valid layer index. + * Otherwise, must be zero. + * + * If ::CUarrayMapInfo::resource::array or ::CUarrayMapInfo::resource::mipmap + was created with ::CUDA_ARRAY3D_DEFERRED_MAPPING + * flag set the ::CUarrayMapInfo::subresourceType and the contents of + ::CUarrayMapInfo::subresource will be ignored. + * + * ::CUarrayMapInfo::memOperationType specifies the type of operation. + ::CUmemOperationType is defined as: \code typedef enum CUmemOperationType_enum + { CU_MEM_OPERATION_TYPE_MAP = 1, CU_MEM_OPERATION_TYPE_UNMAP = 2 } + CUmemOperationType; \endcode + * If ::CUarrayMapInfo::memOperationType is set to + ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP then the subresource + * will be mapped onto the tile pool memory specified by + ::CUarrayMapInfo::memHandle at offset ::CUarrayMapInfo::offset. + * The tile pool allocation has to be created by specifying the + ::CU_MEM_CREATE_USAGE_TILE_POOL flag when calling ::cuMemCreate. Also, + * ::CUarrayMapInfo::memHandleType must be set to + ::CUmemHandleType::CU_MEM_HANDLE_TYPE_GENERIC. + * + * If ::CUarrayMapInfo::memOperationType is set to + ::CUmemOperationType::CU_MEM_OPERATION_TYPE_UNMAP then an unmapping operation + * is performed. ::CUarrayMapInfo::memHandle must be NULL. + * + * ::CUarrayMapInfo::deviceBitMask specifies the list of devices that must map + or unmap physical memory. + * Currently, this mask must have exactly one bit set, and the corresponding + device must match the device associated with the stream. + * If ::CUarrayMapInfo::memOperationType is set to + ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP, the device must also match + * the device associated with the tile pool memory allocation as specified by + ::CUarrayMapInfo::memHandle. + * + * ::CUarrayMapInfo::flags and ::CUarrayMapInfo::reserved[] are unused and must + be set to zero. 
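+ *
+ * As an illustrative sketch (not part of the original header text), mapping the
+ * first mip level of a 2D sparse CUDA array onto a tile pool could look roughly
+ * like the following; the array, the tile pool handle, the tile-aligned extents
+ * and the device ordinal are assumed to exist already:
+ \code
+    CUarrayMapInfo mapInfo = {0};                      // zeroes flags and reserved[]
+    mapInfo.resourceType = CU_RESOURCE_TYPE_ARRAY;
+    mapInfo.resource.array = sparseArray;              // created with CUDA_ARRAY3D_SPARSE
+    mapInfo.subresourceType = CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL;
+    mapInfo.subresource.sparseLevel.level = 0;
+    mapInfo.subresource.sparseLevel.extentWidth  = tileAlignedWidth;
+    mapInfo.subresource.sparseLevel.extentHeight = tileAlignedHeight;
+    mapInfo.subresource.sparseLevel.extentDepth  = 1;  // 2D array
+    mapInfo.memOperationType = CU_MEM_OPERATION_TYPE_MAP;
+    mapInfo.memHandleType = CU_MEM_HANDLE_TYPE_GENERIC;
+    mapInfo.memHandle.memHandle = tilePoolHandle;      // from cuMemCreate with CU_MEM_CREATE_USAGE_TILE_POOL
+    mapInfo.offset = 0;
+    mapInfo.deviceBitMask = 1U << deviceOrdinal;       // must match the stream's device
+    cuMemMapArrayAsync(&mapInfo, 1, hStream);
+ \endcode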
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \param[in] mapInfoList - List of ::CUarrayMapInfo + * \param[in] count - Count of ::CUarrayMapInfo in \p mapInfoList + * \param[in] hStream - Stream identifier for the stream to use for map or + unmap operations + * + * \sa ::cuMipmappedArrayCreate, ::cuArrayCreate, ::cuArray3DCreate, + ::cuMemCreate, ::cuArrayGetSparseProperties, + ::cuMipmappedArrayGetSparseProperties + */ +CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, + unsigned int count, CUstream hStream); + +/** + * \brief Unmap the backing memory of a given address range. + * + * The range must be the entire contiguous address range that was mapped to. In + * other words, ::cuMemUnmap cannot unmap a sub-range of an address range mapped + * by ::cuMemCreate / ::cuMemMap. Any backing memory allocations will be freed + * if there are no existing mappings and there are no unreleased memory handles. + * + * When ::cuMemUnmap returns successfully the address range is converted to an + * address reservation and can be used for a future calls to ::cuMemMap. Any + * new mapping to this virtual address will need to have access granted through + * ::cuMemSetAccess, as all mappings start with no accessibility setup. + * + * \param[in] ptr - Starting address for the virtual address range to unmap + * \param[in] size - Size of the virtual address range to unmap + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * \note_sync + * + * \sa ::cuMemCreate, ::cuMemAddressReserve + */ +CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size); + +/** + * \brief Set the access flags for each location specified in \p desc for the + * given virtual address range + * + * Given the virtual address range via \p ptr and \p size, and the locations + * in the array given by \p desc and \p count, set the access flags for the + * target locations. The range must be a fully mapped address range + * containing all allocations created by ::cuMemMap / ::cuMemCreate. + * Users cannot specify ::CU_MEM_LOCATION_TYPE_HOST_NUMA accessibility for + * allocations created on with other location types. Note: When + * ::CUmemAccessDesc::CUmemLocation::type is ::CU_MEM_LOCATION_TYPE_HOST_NUMA, + * ::CUmemAccessDesc::CUmemLocation::id is ignored. When setting the access + * flags for a virtual address range mapping a multicast object, \p ptr and \p + * size must be aligned to the value returned by + * ::cuMulticastGetGranularity with the flag ::CU_MULTICAST_MINIMUM_GRANULARITY. + * For best performance however, it is recommended that \p ptr and \p size be + * aligned to the value returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_RECOMMENDED_GRANULARITY. 
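+ *
+ * As a hedged sketch (not part of the original header text), a typical
+ * reserve / create / map / set-access sequence for device memory might look as
+ * follows; error checking is omitted and \c deviceOrdinal and \c numChunks are
+ * assumed:
+ \code
+    size_t granularity = 0;
+    CUmemAllocationProp prop = {0};
+    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+    prop.location.id = deviceOrdinal;
+    cuMemGetAllocationGranularity(&granularity, &prop,
+                                  CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+
+    size_t size = numChunks * granularity;         // must be a granularity multiple
+    CUmemGenericAllocationHandle handle;
+    cuMemCreate(&handle, size, &prop, 0);
+
+    CUdeviceptr ptr;
+    cuMemAddressReserve(&ptr, size, 0, 0, 0);      // reserve a VA range
+    cuMemMap(ptr, size, 0, handle, 0);             // back it with the allocation
+
+    CUmemAccessDesc access = {0};
+    access.location = prop.location;
+    access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    cuMemSetAccess(ptr, size, &access, 1);         // grant read/write on the device
+ \endcode
+ * Teardown reverses the order: ::cuMemUnmap, ::cuMemAddressFree and
+ * ::cuMemRelease.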
+ * + * \param[in] ptr - Starting address for the virtual address range + * \param[in] size - Length of the virtual address range + * \param[in] desc - Array of ::CUmemAccessDesc that describe how to change the + * - mapping for each location specified + * \param[in] count - Number of ::CUmemAccessDesc in \p desc + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * \note_sync + * + * \sa ::cuMemSetAccess, ::cuMemCreate, ::cuMemMap + */ +CUresult CUDAAPI cuMemSetAccess(CUdeviceptr ptr, size_t size, + const CUmemAccessDesc *desc, size_t count); + +/** + * \brief Get the access \p flags set for the given \p location and \p ptr + * + * \param[out] flags - Flags set for this location + * \param[in] location - Location in which to check the flags for + * \param[in] ptr - Address in which to check the access flags for + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemSetAccess + */ +CUresult CUDAAPI cuMemGetAccess(unsigned long long *flags, + const CUmemLocation *location, CUdeviceptr ptr); + +/** + * \brief Exports an allocation to a requested shareable handle type + * + * Given a CUDA memory handle, create a shareable memory + * allocation handle that can be used to share the memory with other + * processes. The recipient process can convert the shareable handle back into a + * CUDA memory handle using ::cuMemImportFromShareableHandle and map + * it with ::cuMemMap. The implementation of what this handle is and how it + * can be transferred is defined by the requested handle type in \p handleType + * + * Once all shareable handles are closed and the allocation is released, the + * allocated memory referenced will be released back to the OS and uses of the + * CUDA handle afterward will lead to undefined behavior. + * + * This API can also be used in conjunction with other APIs (e.g. Vulkan, + * OpenGL) that support importing memory from the shareable type + * + * \param[out] shareableHandle - Pointer to the location in which to store the + * requested handle type \param[in] handle - CUDA handle for the + * memory allocation \param[in] handleType - Type of shareable handle + * requested (defines type and size of the \p shareableHandle output parameter) + * \param[in] flags - Reserved, must be zero + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMemExportToShareableHandle( + void *shareableHandle, CUmemGenericAllocationHandle handle, + CUmemAllocationHandleType handleType, unsigned long long flags); + +/** + * \brief Imports an allocation from a requested shareable handle type. + * + * If the current process cannot support the memory described by this shareable + * handle, this API will error as ::CUDA_ERROR_NOT_SUPPORTED. + * + * If \p shHandleType is ::CU_MEM_HANDLE_TYPE_FABRIC and the importer process + * has not been granted access to the same IMEX channel as the exporter process, + * this API will error as ::CUDA_ERROR_NOT_PERMITTED.
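+ *
+ * As an illustrative sketch (not part of the original header text), exporting
+ * an allocation as a POSIX file descriptor in one process and importing it in
+ * another might look as follows; transporting the descriptor between the two
+ * processes (e.g. over a UNIX domain socket) is assumed:
+ \code
+    // Exporting process: the allocation's CUmemAllocationProp::requestedHandleTypes
+    // must have included CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR.
+    int fd = -1;
+    cuMemExportToShareableHandle(&fd, handle,
+                                 CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0);
+
+    // Importing process: the received descriptor is passed by value.
+    CUmemGenericAllocationHandle imported;
+    cuMemImportFromShareableHandle(&imported, (void *)(uintptr_t)fd,
+                                   CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR);
+    // Map with cuMemMap and grant access with cuMemSetAccess before use.
+ \endcode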
+ * + * \note Importing shareable handles exported from some graphics APIs (Vulkan, + * OpenGL, etc.) created on devices under an SLI group may not be supported, and + * thus this API will return CUDA_ERROR_NOT_SUPPORTED. There is no guarantee + * that the contents of \p handle will be the same CUDA memory handle for the + * same given OS shareable handle, or the same underlying allocation. + * + * \param[out] handle - CUDA Memory handle for the memory allocation. + * \param[in] osHandle - Shareable Handle representing the memory + * allocation that is to be imported. \param[in] shHandleType - handle type of + * the exported handle ::CUmemAllocationHandleType. \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemExportToShareableHandle, ::cuMemMap, ::cuMemRelease + */ +CUresult CUDAAPI cuMemImportFromShareableHandle( + CUmemGenericAllocationHandle *handle, void *osHandle, + CUmemAllocationHandleType shHandleType); + +/** + * \brief Calculates either the minimal or recommended granularity + * + * Calculates either the minimal or recommended granularity + * for a given allocation specification and returns it in granularity. This + * granularity can be used as a multiple for alignment, size, or address + * mapping. + * + * \param[out] granularity Returned granularity. + * \param[in] prop Property for which to determine the granularity for + * \param[in] option Determines which granularity to return + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemCreate, ::cuMemMap + */ +CUresult CUDAAPI cuMemGetAllocationGranularity( + size_t *granularity, const CUmemAllocationProp *prop, + CUmemAllocationGranularity_flags option); + +/** + * \brief Retrieve the contents of the property structure defining properties + * for this handle + * + * \param[out] prop - Pointer to a properties structure which will hold the + * information about this handle \param[in] handle - Handle which to perform the + * query on \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemCreate, ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMemGetAllocationPropertiesFromHandle( + CUmemAllocationProp *prop, CUmemGenericAllocationHandle handle); + +/** + * \brief Given an address \p addr, returns the allocation handle of the backing + * memory allocation. + * + * The handle is guaranteed to be the same handle value used to map the memory. + * If the address requested is not mapped, the function will fail. The returned + * handle must be released with corresponding number of calls to ::cuMemRelease. + * + * \note The address \p addr, can be any address in a range previously mapped + * by ::cuMemMap, and not necessarily the start address. + * + * \param[out] handle CUDA Memory handle for the backing memory allocation. + * \param[in] addr Memory address to query, that has been mapped previously.
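+ *
+ * As a brief illustrative sketch (not part of the original header text),
+ * assuming \c mappedPtr lies in a range previously mapped with ::cuMemMap:
+ \code
+    CUmemGenericAllocationHandle backing;
+    cuMemRetainAllocationHandle(&backing, (void *)(uintptr_t)mappedPtr);
+    // ... inspect it, e.g. with cuMemGetAllocationPropertiesFromHandle ...
+    cuMemRelease(backing);   // balance the retain
+ \endcode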
+ * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMemCreate, ::cuMemRelease, ::cuMemMap + */ +CUresult CUDAAPI +cuMemRetainAllocationHandle(CUmemGenericAllocationHandle *handle, void *addr); + +/** @} */ /* END CUDA_VA */ + +/** + * \defgroup CUDA_MALLOC_ASYNC Stream Ordered Memory Allocator + * + * ___MANBRIEF___ Functions for performing allocation and free operations in + * stream order. Functions for controlling the behavior of the underlying + * allocator. + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream ordered memory allocator exposed by the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_MALLOC_ASYNC_overview overview + * + * The asynchronous allocator allows the user to allocate and free in stream + * order. All asynchronous accesses of the allocation must happen between the + * stream executions of the allocation and the free. If the memory is accessed + * outside of the promised stream order, a use before allocation / use after + * free error will cause undefined behavior. + * + * The allocator is free to reallocate the memory as long as it can guarantee + * that compliant memory accesses will not overlap temporally. + * The allocator may refer to internal stream ordering as well as inter-stream + * dependencies (such as CUDA events and null stream dependencies) when + * establishing the temporal guarantee. The allocator may also insert + * inter-stream dependencies to establish the temporal guarantee. + * + * \section CUDA_MALLOC_ASYNC_support Supported Platforms + * + * Whether or not a device supports the integrated stream ordered memory + * allocator may be queried by calling ::cuDeviceGetAttribute() with the device + * attribute + * ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED + */ + +/** + * \brief Frees memory with stream ordered semantics + * + * Inserts a free operation into \p hStream. + * The allocation must not be accessed after stream execution reaches the free. + * After this API returns, accessing the memory from any subsequent work + * launched on the GPU or querying its pointer attributes results in undefined + * behavior. + * + * \note During stream capture, this function results in the creation of a free + * node and must therefore be passed the address of a graph allocation. + * + * \param dptr - memory to free + * \param hStream - The stream establishing the stream ordering contract. + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current + * context), + * ::CUDA_ERROR_NOT_SUPPORTED + */ +CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); + +/** + * \brief Allocates memory with stream ordered semantics + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation + * completes. The allocation comes from the memory pool current to the stream's + * device. + * + * \note The default memory pool of a device contains device memory from that + * device. \note Basic stream ordering allows future work submitted into the + * same stream to use the allocation. 
Stream query, stream synchronize, and CUDA + * events can be used to guarantee that the allocation operation completes + * before work submitted in a separate stream runs. \note During stream capture, + * this function results in the creation of an allocation node. In this case, + * the allocation is owned by the graph instead of the memory pool. The + * memory pool's properties are used to set the node's creation parameters. + * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] hStream - The stream establishing the stream ordering contract + * and the memory pool to allocate from \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current + * context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocFromPoolAsync, ::cuMemFreeAsync, ::cuDeviceSetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t bytesize, + CUstream hStream); + +/** + * \brief Tries to release memory back to the OS + * + * Releases memory back to the OS until the pool contains fewer than + * minBytesToKeep reserved bytes, or there is no more memory that the allocator + * can safely release. The allocator cannot release OS allocations that back + * outstanding asynchronous allocations. The OS allocations may happen at + * different granularity from the user allocations. + * + * \note: Allocations that have not been freed count as outstanding. + * \note: Allocations that have been asynchronously freed but whose completion + * has not been observed on the host (eg. by a synchronize) can count as + * outstanding. + * + * \param[in] pool - The memory pool to trim + * \param[in] minBytesToKeep - If the pool has less than minBytesToKeep + * reserved, the TrimTo operation is a no-op. Otherwise the pool will be + * guaranteed to have at least minBytesToKeep bytes reserved after the + * operation. \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolTrimTo(CUmemoryPool pool, size_t minBytesToKeep); + +/** + * \brief Sets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before + * trying to release memory back to the OS. When more than the release threshold + * bytes of memory are held by the memory pool, the allocator will try to + * release memory back to the OS on the next call to stream, event or context + * synchronize. (default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the + * required stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no + * dependency between the free and allocation. 
(default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to + * reuse a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of backing + * memory that was allocated for the memory pool. It is illegal to set this + * attribute to a non-zero value. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of used + * memory that was allocated for the memory pool. + * + * \param[in] pool - The memory pool to modify + * \param[in] attr - The attribute to modify + * \param[in] value - Pointer to the value to assign + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAttribute(CUmemoryPool pool, + CUmemPool_attribute attr, void *value); + +/** + * \brief Gets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before + * trying to release memory back to the OS. When more than the release threshold + * bytes of memory are held by the memory pool, the allocator will try to + * release memory back to the OS on the next call to stream, event or context + * synchronize. (default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the + * required stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no + * dependency between the free and allocation. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to + * reuse a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of backing memory currently allocated for the + * mempool + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool + * since the last time it was reset. + * - ::CU_MEMPOOL_ATTR_USED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by + * the application. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that + * was in use by the application. 
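+ *
+ * As an illustrative sketch (not part of the original header text), raising the
+ * release threshold of a device's default pool and later reading back the
+ * reserved-memory high watermark might look as follows; \c device is assumed:
+ \code
+    CUmemoryPool pool;
+    cuDeviceGetDefaultMemPool(&pool, device);
+
+    cuuint64_t threshold = UINT64_MAX;              // keep freed memory in the pool
+    cuMemPoolSetAttribute(pool, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &threshold);
+
+    cuuint64_t reservedHigh = 0;
+    cuMemPoolGetAttribute(pool, CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &reservedHigh);
+ \endcode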
+ * + * \param[in] pool - The memory pool to get attributes of + * \param[in] attr - The attribute to get + * \param[out] value - Retrieved value + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAttribute(CUmemoryPool pool, + CUmemPool_attribute attr, void *value); + +/** + * \brief Controls visibility of pools between devices + * + * \param[in] pool - The pool being modified + * \param[in] map - Array of access descriptors. Each descriptor instructs the + * access to enable for a single gpu. \param[in] count - Number of descriptors + * in the map array. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAccess(CUmemoryPool pool, + const CUmemAccessDesc *map, size_t count); + +/** + * \brief Returns the accessibility of a pool from a device + * + * Returns the accessibility of the pool's memory from the specified location. + * + * \param[out] flags - the accessibility of the pool from the specified + * location \param[in] memPool - the pool being queried \param[in] location - + * the location accessing the pool + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAccess(CUmemAccess_flags *flags, + CUmemoryPool memPool, + CUmemLocation *location); + +/** + * \brief Creates a memory pool + * + * Creates a CUDA memory pool and returns the handle in \p pool. The \p + * poolProps determines the properties of the pool such as the backing device + * and IPC capabilities. + * + * To create a memory pool targeting a specific host NUMA node, applications + * must set ::CUmemPoolProps::CUmemLocation::type to + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA and + * ::CUmemPoolProps::CUmemLocation::id must specify the NUMA ID of the host + * memory node. By default, the pool's memory will be accessible from the device + * it is allocated on. In the case of pools created with + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA, their default accessibility will be from + * the host CPU. Applications can control the maximum size of the pool by + * specifying a non-zero value for ::CUmemPoolProps::maxSize. If set to 0, the + * maximum size of the pool will default to a system dependent value. + * + * Applications can set ::CUmemPoolProps::handleTypes to + * ::CU_MEM_HANDLE_TYPE_FABRIC in order to create ::CUmemoryPool suitable for + * sharing within an IMEX domain. An IMEX domain is either an OS instance or a + * group of securely connected OS instances using the NVIDIA IMEX daemon. An + * IMEX channel is a global resource within the IMEX domain that represents a + * logical entity that aims to provide fine grained accessibility control for + * the participating processes. When exporter and importer CUDA processes have + * been granted access to the same IMEX channel, they can securely share memory. + * If the allocating process does not have access setup for an IMEX channel, + * attempting to export a ::CUmemoryPool with ::CU_MEM_HANDLE_TYPE_FABRIC will + * result in ::CUDA_ERROR_NOT_PERMITTED. The nvidia-modprobe CLI provides more + * information regarding setting up of IMEX channels. 
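+ *
+ * As an illustrative sketch (not part of the original header text), creating a
+ * device-local pool whose allocations can be exported as POSIX file descriptors,
+ * and allocating from it, might look as follows; \c deviceOrdinal, \c bytes and
+ * \c hStream are assumed:
+ \code
+    CUmemPoolProps props = {0};
+    props.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
+    props.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+    props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+    props.location.id = deviceOrdinal;
+
+    CUmemoryPool pool;
+    cuMemPoolCreate(&pool, &props);
+
+    CUdeviceptr dptr;
+    cuMemAllocFromPoolAsync(&dptr, bytes, pool, hStream);
+    // ... launch work on hStream that uses dptr ...
+    cuMemFreeAsync(dptr, hStream);
+ \endcode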
+ * + * \note Specifying CU_MEM_HANDLE_TYPE_NONE creates a memory pool that will not + * support IPC. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NOT_PERMITTED + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, ::cuDeviceGetDefaultMemPool, + * ::cuMemAllocFromPoolAsync, ::cuMemPoolExportToShareableHandle + */ +CUresult CUDAAPI cuMemPoolCreate(CUmemoryPool *pool, + const CUmemPoolProps *poolProps); + +/** + * \brief Destroys the specified memory pool + * + * If any pointers obtained from this pool haven't been freed or + * the pool has free operations that haven't completed + * when ::cuMemPoolDestroy is invoked, the function will return immediately and + * the resources associated with the pool will be released automatically once + * there are no more outstanding allocations. + * + * Destroying the current mempool of a device sets the default mempool of + * that device as the current mempool for that device. + * + * \note A device's default memory pool cannot be destroyed. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemFreeAsync, ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolDestroy(CUmemoryPool pool); + +/** + * \brief Allocates memory from a specified pool with stream ordered semantics. + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation + * completes. The allocation comes from the specified memory pool. + * + * \note + * - The specified memory pool may be from a device different than that of + * the specified \p hStream. + * + * - Basic stream ordering allows future work submitted into the same stream + * to use the allocation. Stream query, stream synchronize, and CUDA events can + * be used to guarantee that the allocation operation completes before work + * submitted in a separate stream runs. + * + * \note During stream capture, this function results in the creation of an + * allocation node. In this case, the allocation is owned by the graph instead + * of the memory pool. The memory pool's properties are used to set the node's + * creation parameters. + * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] pool - The pool to allocate from + * \param[in] hStream - The stream establishing the stream ordering semantic + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current + * context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate, ::cuMemPoolSetAccess, + * ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, + CUmemoryPool pool, CUstream hStream); + +/** + * \brief Exports a memory pool to the requested handle type. + * + * Given an IPC capable mempool, create an OS handle to share the pool with + * another process. A recipient process can convert the shareable handle into a + * mempool with ::cuMemPoolImportFromShareableHandle. 
Individual pointers can + * then be shared with the ::cuMemPoolExportPointer and ::cuMemPoolImportPointer + * APIs. The implementation of what the shareable handle is and how it can be + * transferred is defined by the requested handle type. + * + * \note: To create an IPC capable mempool, create a mempool with a + * CUmemAllocationHandleType other than CU_MEM_HANDLE_TYPE_NONE. + * + * \param[out] handle_out - Returned OS handle + * \param[in] pool - pool to export + * \param[in] handleType - the type of handle to create + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolImportFromShareableHandle, ::cuMemPoolExportPointer, + * ::cuMemPoolImportPointer, ::cuMemAllocAsync, ::cuMemFreeAsync, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemPoolExportToShareableHandle( + void *handle_out, CUmemoryPool pool, CUmemAllocationHandleType handleType, + unsigned long long flags); + +/** + * \brief imports a memory pool from a shared handle. + * + * Specific allocations can be imported from the imported pool with + * cuMemPoolImportPointer. + * + * If \p handleType is ::CU_MEM_HANDLE_TYPE_FABRIC and the importer process has + * not been granted access to the same IMEX channel as the exporter process, + * this API will error as ::CUDA_ERROR_NOT_PERMITTED. + * + * + * \note Imported memory pools do not support creating new allocations. + * As such imported memory pools may not be used in cuDeviceSetMemPool + * or ::cuMemAllocFromPoolAsync calls. + * + * \param[out] pool_out - Returned memory pool + * \param[in] handle - OS handle of the pool to open + * \param[in] handleType - The type of handle being imported + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolExportPointer, + * ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolImportFromShareableHandle( + CUmemoryPool *pool_out, void *handle, CUmemAllocationHandleType handleType, + unsigned long long flags); + +/** + * \brief Export data to share a memory pool allocation between processes. + * + * Constructs \p shareData_out for sharing a specific allocation from an already + * shared memory pool. The recipient process can import the allocation with the + * ::cuMemPoolImportPointer api. The data is not a handle and may be shared + * through any IPC mechanism. + * + * \param[out] shareData_out - Returned export data + * \param[in] ptr - pointer to memory being exported + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, + * ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolExportPointer(CUmemPoolPtrExportData *shareData_out, + CUdeviceptr ptr); + +/** + * \brief Import a memory pool allocation from another process. + * + * Returns in \p ptr_out a pointer to the imported memory. + * The imported memory must not be accessed before the allocation operation + * completes in the exporting process. The imported memory must be freed from + * all importing processes before being freed in the exporting process. The + * pointer may be freed with cuMemFree or cuMemFreeAsync. 
If cuMemFreeAsync is + * used, the free must be completed on the importing process before the free + * operation on the exporting process. + * + * \note The cuMemFreeAsync api may be used in the exporting process before + * the cuMemFreeAsync operation completes in its stream as long as the + * cuMemFreeAsync in the exporting process specifies a stream with + * a stream dependency on the importing process's cuMemFreeAsync. + * + * \param[out] ptr_out - pointer to imported memory + * \param[in] pool - pool from which to import + * \param[in] shareData - data specifying the memory to import + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, + * ::cuMemPoolExportPointer + */ +CUresult CUDAAPI cuMemPoolImportPointer(CUdeviceptr *ptr_out, CUmemoryPool pool, + CUmemPoolPtrExportData *shareData); + +/** @} */ /* END CUDA_MALLOC_ASYNC */ + +/** + * \defgroup CUDA_MULTICAST Multicast Object Management + * + * ___MANBRIEF___ Functions for creating multicast objects, adding devices to + * them and binding/unbinding memory + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the CUDA multicast object operations exposed by the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_MULTICAST_overview overview + * + * A multicast object created via ::cuMulticastCreate enables certain memory + * operations to be broadcast to a team of devices. Devices can be added to a + * multicast object via ::cuMulticastAddDevice. Memory can be bound on each + * participating device via either ::cuMulticastBindMem or + * ::cuMulticastBindAddr. Multicast objects can be mapped into a device's + * virtual address space using the virtual memory management APIs (see + * ::cuMemMap and ::cuMemSetAccess). + * + * \section CUDA_MULTICAST_support Supported Platforms + * + * Support for multicast on a specific device can be queried using the device + * attribute ::CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED + */ + +/** + * \brief Create a generic allocation handle representing a multicast object + * described by the given properties. + * + * This creates a multicast object as described by \p prop. The number of + * participating devices is specified by ::CUmulticastObjectProp::numDevices. + * Devices can be added to the multicast object via ::cuMulticastAddDevice. + * All participating devices must be added to the multicast object before memory + * can be bound to it. Memory is bound to the multicast object via either + * ::cuMulticastBindMem or ::cuMulticastBindAddr, and can be unbound via + * ::cuMulticastUnbind. The total amount of memory that can be bound per device + * is specified by :CUmulticastObjectProp::size. This size must be a multiple of + * the value returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_MINIMUM. For best performance however, the size + * should be aligned to the value returned by ::cuMulticastGetGranularity with + * the flag ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * After all participating devices have been added, multicast objects can also + * be mapped to a device's virtual address space using the virtual memory + * management APIs (see ::cuMemMap and ::cuMemSetAccess). Multicast objects can + * also be shared with other processes by requesting a shareable handle via + * ::cuMemExportToShareableHandle. 
Note that the desired types of shareable + * handles must be specified in the bitmask + * ::CUmulticastObjectProp::handleTypes. Multicast objects can be released using + * the virtual memory management API + * ::cuMemRelease. + * + * \param[out] mcHandle Value of handle returned. + * \param[in] prop Properties of the multicast object to create. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastAddDevice, ::cuMulticastBindMem, ::cuMulticastBindAddr, + * ::cuMulticastUnbind \sa ::cuMemCreate, ::cuMemRelease, + * ::cuMemExportToShareableHandle, ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMulticastCreate(CUmemGenericAllocationHandle *mcHandle, + const CUmulticastObjectProp *prop); + +/** + * \brief Associate a device to a multicast object. + * + * Associates a device to a multicast object. The added device will be a part of + * the multicast team of size specified by CUmulticastObjectProp::numDevices + * during ::cuMulticastCreate. + * The association of the device to the multicast object is permanent during + * the life time of the multicast object. + * All devices must be added to the multicast team before any memory can be + * bound to any device in the team. Any calls to ::cuMulticastBindMem or + * ::cuMulticastBindAddr will block until all devices have been added. + * Similarly all devices must be added to the multicast team before a virtual + * address range can be mapped to the multicast object. A call to ::cuMemMap + * will block until all devices have been added. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] dev Device that will be associated to the multicast + * object. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastCreate, ::cuMulticastBindMem, ::cuMulticastBindAddr + */ +CUresult CUDAAPI cuMulticastAddDevice(CUmemGenericAllocationHandle mcHandle, + CUdevice dev); + +/** + * \brief Bind a memory allocation represented by a handle to a multicast + * object. + * + * Binds a memory allocation specified by \p memHandle and created via + * ::cuMemCreate to a multicast object represented by \p mcHandle and created + * via ::cuMulticastCreate. The intended \p size of the bind, the offset in the + * multicast range \p mcOffset as well as the offset in the memory \p memOffset + * must be a multiple of the value returned by ::cuMulticastGetGranularity with + * the flag ::CU_MULTICAST_GRANULARITY_MINIMUM. For best performance however, + * \p size, \p mcOffset and \p memOffset should be aligned to the granularity of + * the memory allocation(see ::cuMemGetAllocationGranularity) or to the value + * returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * The \p size + \p memOffset must be smaller than the size of the allocated + * memory. Similarly the \p size + \p mcOffset must be smaller than the size + * of the multicast object. + * The memory allocation must have been created on one of the devices + * that was added to the multicast team via ::cuMulticastAddDevice. 
+ * Externally shareable as well as imported multicast objects can be bound only + * to externally shareable memory. + * Note that this call will return CUDA_ERROR_OUT_OF_MEMORY if there are + * insufficient resources required to perform the bind. This call may also + * return CUDA_ERROR_SYSTEM_NOT_READY if the necessary system software is not + * initialized or running. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] mcOffset Offset into the multicast object for attachment. + * \param[in] memHandle Handle representing a memory allocation. + * \param[in] memOffset Offset into the memory for attachment. + * \param[in] size Size of the memory that will be bound to the + * multicast object. + * \param[in] flags Flags for future use, must be zero for now. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_SYSTEM_NOT_READY + * + * \sa ::cuMulticastCreate, ::cuMulticastAddDevice, ::cuMemCreate + */ +CUresult CUDAAPI cuMulticastBindMem(CUmemGenericAllocationHandle mcHandle, + size_t mcOffset, + CUmemGenericAllocationHandle memHandle, + size_t memOffset, size_t size, + unsigned long long flags); + +/** + * \brief Bind a memory allocation represented by a virtual address to a + * multicast object. + * + * Binds a memory allocation specified by its mapped address \p memptr to a + * multicast object represented by \p mcHandle. + * The memory must have been allocated via ::cuMemCreate or ::cudaMallocAsync. + * The intended \p size of the bind, the offset in the multicast range + * \p mcOffset and \p memptr must be a multiple of the value returned by + * ::cuMulticastGetGranularity with the flag ::CU_MULTICAST_GRANULARITY_MINIMUM. + * For best performance however, \p size, \p mcOffset and \p memptr should be + * aligned to the value returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * The \p size must be smaller than the size of the allocated memory. + * Similarly the \p size + \p mcOffset must be smaller than the total size + * of the multicast object. + * The memory allocation must have been created on one of the devices + * that was added to the multicast team via ::cuMulticastAddDevice. + * Externally shareable as well as imported multicast objects can be bound only + * to externally shareable memory. + * Note that this call will return CUDA_ERROR_OUT_OF_MEMORY if there are + * insufficient resources required to perform the bind. This call may also + * return CUDA_ERROR_SYSTEM_NOT_READY if the necessary system software is not + * initialized or running. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] mcOffset Offset into multicast va range for attachment. + * \param[in] memptr Virtual address of the memory allocation. + * \param[in] size Size of memory that will be bound to the + * multicast object. + * \param[in] flags Flags for future use, must be zero now. 
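+ *
+ * As an illustrative recap (not part of the original header text), the overall
+ * multicast setup for two devices might look as follows; \c bytes, \c dev0,
+ * \c dev1 and a device-local allocation handle \c memHandleDev0 of at least the
+ * rounded size are assumed:
+ \code
+    CUmulticastObjectProp mcProp = {0};
+    mcProp.numDevices = 2;
+    mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+
+    size_t gran = 0;
+    cuMulticastGetGranularity(&gran, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED);
+    mcProp.size = ((bytes + gran - 1) / gran) * gran;   // round up to the granularity
+
+    CUmemGenericAllocationHandle mcHandle;
+    cuMulticastCreate(&mcHandle, &mcProp);
+    cuMulticastAddDevice(mcHandle, dev0);
+    cuMulticastAddDevice(mcHandle, dev1);
+
+    // Each device binds its own physical allocation at multicast offset 0.
+    cuMulticastBindMem(mcHandle, 0, memHandleDev0, 0, mcProp.size, 0);
+ \endcode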
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_SYSTEM_NOT_READY + * + * \sa ::cuMulticastCreate, ::cuMulticastAddDevice, ::cuMemCreate + */ +CUresult CUDAAPI cuMulticastBindAddr(CUmemGenericAllocationHandle mcHandle, + size_t mcOffset, CUdeviceptr memptr, + size_t size, unsigned long long flags); + +/** + * \brief Unbind any memory allocations bound to a multicast object at a given + * offset and upto a given size. + * + * Unbinds any memory allocations hosted on \p dev and bound to a multicast + * object at \p mcOffset and upto a given \p size. + * The intended \p size of the unbind and the offset in the multicast range + * ( \p mcOffset ) must be a multiple of the value returned by + * ::cuMulticastGetGranularity flag ::CU_MULTICAST_GRANULARITY_MINIMUM. + * The \p size + \p mcOffset must be smaller than the total size of the + * multicast object. + * + * \note + * Warning: + * The \p mcOffset and the \p size must match the corresponding values specified + * during the bind call. Any other values may result in undefined behavior. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] dev Device that hosts the memory allocation. + * \param[in] mcOffset Offset into the multicast object. + * \param[in] size Desired size to unbind. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastBindMem, ::cuMulticastBindAddr + */ +CUresult CUDAAPI cuMulticastUnbind(CUmemGenericAllocationHandle mcHandle, + CUdevice dev, size_t mcOffset, size_t size); + +/** + * \brief Calculates either the minimal or recommended granularity for multicast + * object + * + * Calculates either the minimal or recommended granularity for a given set of + * multicast object properties and returns it in granularity. This granularity + * can be used as a multiple for size, bind offsets and address mappings of the + * multicast object. + * + * \param[out] granularity Returned granularity. + * \param[in] prop Properties of the multicast object. + * \param[in] option Determines which granularity to return. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastCreate, ::cuMulticastBindMem, ::cuMulticastBindAddr, + * ::cuMulticastUnbind + */ +CUresult CUDAAPI cuMulticastGetGranularity(size_t *granularity, + const CUmulticastObjectProp *prop, + CUmulticastGranularity_flags option); + +/** @} */ /* END CUDA_MULTICAST */ + +/** + * \defgroup CUDA_UNIFIED Unified Addressing + * + * ___MANBRIEF___ unified addressing functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the unified addressing functions of the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_UNIFIED_overview Overview + * + * CUDA devices can share a unified address space with the host. 
+ * For these devices there is no distinction between a device + * pointer and a host pointer -- the same pointer value may be + * used to access memory from the host program and from a kernel + * running on the device (with exceptions enumerated below). + * + * \section CUDA_UNIFIED_support Supported Platforms + * + * Whether or not a device supports unified addressing may be + * queried by calling ::cuDeviceGetAttribute() with the device + * attribute ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING. + * + * Unified addressing is automatically enabled in 64-bit processes + * + * \section CUDA_UNIFIED_lookup Looking Up Information from Pointer Values + * + * It is possible to look up information about the memory which backs a + * pointer value. For instance, one may want to know if a pointer points + * to host or device memory. As another example, in the case of device + * memory, one may want to know on which CUDA device the memory + * resides. These properties may be queried using the function + * ::cuPointerGetAttribute() + * + * Since pointers are unique, it is not necessary to specify information + * about the pointers specified to the various copy functions in the + * CUDA API. The function ::cuMemcpy() may be used to perform a copy + * between two pointers, ignoring whether they point to host or device + * memory (making ::cuMemcpyHtoD(), ::cuMemcpyDtoD(), and ::cuMemcpyDtoH() + * unnecessary for devices supporting unified addressing). For + * multidimensional copies, the memory type ::CU_MEMORYTYPE_UNIFIED may be + * used to specify that the CUDA driver should infer the location of the + * pointer from its value. + * + * \section CUDA_UNIFIED_automaphost Automatic Mapping of Host Allocated Host + * Memory + * + * All host memory allocated in all contexts using ::cuMemAllocHost() and + * ::cuMemHostAlloc() is always directly accessible from all contexts on + * all devices that support unified addressing. This is the case regardless + * of whether or not the flags ::CU_MEMHOSTALLOC_PORTABLE and + * ::CU_MEMHOSTALLOC_DEVICEMAP are specified. + * + * The pointer value through which allocated host memory may be accessed + * in kernels on all devices that support unified addressing is the same + * as the pointer value through which that memory is accessed on the host, + * so it is not necessary to call ::cuMemHostGetDevicePointer() to get the + * device pointer for these allocations. + * + * Note that this is not the case for memory allocated using the flag + * ::CU_MEMHOSTALLOC_WRITECOMBINED, as discussed below. + * + * \section CUDA_UNIFIED_autopeerregister Automatic Registration of Peer Memory + * + * Upon enabling direct access from a context that supports unified addressing + * to another peer context that supports unified addressing using + * ::cuCtxEnablePeerAccess() all memory allocated in the peer context using + * ::cuMemAlloc() and ::cuMemAllocPitch() will immediately be accessible + * by the current context. The device pointer value through + * which any peer memory may be accessed in the current context + * is the same pointer value through which that memory may be + * accessed in the peer context. + * + * \section CUDA_UNIFIED_exceptions Exceptions, Disjoint Addressing + * + * Not all memory may be accessed on devices through the same pointer + * value through which they are accessed on the host. These exceptions + * are host memory registered using ::cuMemHostRegister() and host memory + * allocated using the flag ::CU_MEMHOSTALLOC_WRITECOMBINED. 
For these + * exceptions, there exists a distinct host and device address for the + * memory. The device address is guaranteed to not overlap any valid host + * pointer range and is guaranteed to have the same value across all + * contexts that support unified addressing. + * + * This device address may be queried using ::cuMemHostGetDevicePointer() + * when a context using unified addressing is current. Either the host + * or the unified device pointer value may be used to refer to this memory + * through ::cuMemcpy() and similar functions using the + * ::CU_MEMORYTYPE_UNIFIED memory type. + * + */ + +/** + * \brief Returns information about a pointer + * + * The supported attributes are: + * + * - ::CU_POINTER_ATTRIBUTE_CONTEXT: + * + * Returns in \p *data the ::CUcontext in which \p ptr was allocated or + * registered. + * The type of \p data must be ::CUcontext *. + * + * If \p ptr was not allocated by, mapped by, or registered with + * a ::CUcontext which uses unified virtual addressing then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * - ::CU_POINTER_ATTRIBUTE_MEMORY_TYPE: + * + * Returns in \p *data the physical memory type of the memory that + * \p ptr addresses as a ::CUmemorytype enumerated value. + * The type of \p data must be unsigned int. + * + * If \p ptr addresses device memory then \p *data is set to + * ::CU_MEMORYTYPE_DEVICE. The particular ::CUdevice on which the + * memory resides is the ::CUdevice of the ::CUcontext returned by the + * ::CU_POINTER_ATTRIBUTE_CONTEXT attribute of \p ptr. + * + * If \p ptr addresses host memory then \p *data is set to + * ::CU_MEMORYTYPE_HOST. + * + * If \p ptr was not allocated by, mapped by, or registered with + * a ::CUcontext which uses unified virtual addressing then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * If the current ::CUcontext does not support unified virtual + * addressing then ::CUDA_ERROR_INVALID_CONTEXT is returned. + * + * - ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER: + * + * Returns in \p *data the device pointer value through which + * \p ptr may be accessed by kernels running in the current + * ::CUcontext. + * The type of \p data must be CUdeviceptr *. + * + * If there exists no device pointer value through which + * kernels running in the current ::CUcontext may access + * \p ptr then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * If there is no current ::CUcontext then + * ::CUDA_ERROR_INVALID_CONTEXT is returned. + * + * Except in the exceptional disjoint addressing cases discussed + * below, the value returned in \p *data will equal the input + * value \p ptr. + * + * - ::CU_POINTER_ATTRIBUTE_HOST_POINTER: + * + * Returns in \p *data the host pointer value through which + * \p ptr may be accessed by by the host program. + * The type of \p data must be void **. + * If there exists no host pointer value through which + * the host program may directly access \p ptr then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Except in the exceptional disjoint addressing cases discussed + * below, the value returned in \p *data will equal the input + * value \p ptr. + * + * - ::CU_POINTER_ATTRIBUTE_P2P_TOKENS: + * + * Returns in \p *data two tokens for use with the nv-p2p.h Linux + * kernel interface. \p data must be a struct of type + * CUDA_POINTER_ATTRIBUTE_P2P_TOKENS. + * + * \p ptr must be a pointer to memory obtained from :cuMemAlloc(). + * Note that p2pToken and vaSpaceToken are only valid for the + * lifetime of the source allocation. 
A subsequent allocation at + * the same address may return completely different tokens. + * Querying this attribute has a side effect of setting the attribute + * ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS for the region of memory that + * \p ptr points to. + * + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS: + * + * A boolean attribute which when set, ensures that synchronous memory + * operations initiated on the region of memory that \p ptr points to will + * always synchronize. See further documentation in the section titled "API + * synchronization behavior" to learn more about cases when synchronous memory + * operations can exhibit asynchronous behavior. + * + * - ::CU_POINTER_ATTRIBUTE_BUFFER_ID: + * + * Returns in \p *data a buffer ID which is guaranteed to be unique within + * the process. \p data must point to an unsigned long long. + * + * \p ptr must be a pointer to memory obtained from a CUDA memory + * allocation API. Every memory allocation from any of the CUDA memory + * allocation APIs will have a unique ID over a process lifetime. Subsequent + * allocations do not reuse IDs from previous freed allocations. IDs are only + * unique within a single process. + * + * + * - ::CU_POINTER_ATTRIBUTE_IS_MANAGED: + * + * Returns in \p *data a boolean that indicates whether the pointer points + * to managed memory or not. + * + * If \p ptr is not a valid CUDA pointer then ::CUDA_ERROR_INVALID_VALUE is + * returned. + * + * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL: + * + * Returns in \p *data an integer representing a device ordinal of a device + * against which the memory was allocated or registered. + * + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE: + * + * Returns in \p *data a boolean that indicates if this pointer maps to + * an allocation that is suitable for ::cudaIpcGetMemHandle. + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR: + * + * Returns in \p *data the starting address for the allocation referenced + * by the device pointer \p ptr. Note that this is not necessarily the + * address of the mapped region, but the address of the mappable address + * range \p ptr references (e.g. from ::cuMemAddressReserve). + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE: + * + * Returns in \p *data the size for the allocation referenced by the device + * pointer \p ptr. Note that this is not necessarily the size of the + * mapped region, but the size of the mappable address range \p ptr references + * (e.g. from ::cuMemAddressReserve). To retrieve the size of the mapped + * region, see ::cuMemGetAddressRange + * + * - ::CU_POINTER_ATTRIBUTE_MAPPED: + * + * Returns in \p *data a boolean that indicates if this pointer is in a + * valid address range that is mapped to a backing allocation. + * + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES: + * + * Returns a bitmask of the allowed handle types for an allocation that may + * be passed to ::cuMemExportToShareableHandle. + * + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE: + * + * Returns in \p *data the handle to the mempool that the allocation was + * obtained from. + * + * \par + * + * Note that for most allocations in the unified virtual address space + * the host and device pointer for accessing the allocation will be the + * same. The exceptions to this are + * - user memory registered using ::cuMemHostRegister + * - host memory allocated using ::cuMemHostAlloc with the + * ::CU_MEMHOSTALLOC_WRITECOMBINED flag + * For these types of allocation there will exist separate, disjoint host + * and device addresses for accessing the allocation. 
In particular + * - The host address will correspond to an invalid unmapped device address + * (which will result in an exception if accessed from the device) + * - The device address will correspond to an invalid unmapped host address + * (which will result in an exception if accessed from the host). + * For these types of allocations, querying ::CU_POINTER_ATTRIBUTE_HOST_POINTER + * and ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER may be used to retrieve the host + * and device addresses from either address. + * + * \param data - Returned pointer attribute value + * \param attribute - Pointer attribute to query + * \param ptr - Pointer + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuPointerSetAttribute, + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuMemAllocHost, + * ::cuMemFreeHost, + * ::cuMemHostAlloc, + * ::cuMemHostRegister, + * ::cuMemHostUnregister, + * ::cudaPointerGetAttributes + */ +CUresult CUDAAPI cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr); + +/** + * \brief Prefetches memory to the specified destination device + * + * Note there is a later version of this API, ::cuMemPrefetchAsync_v2. It will + * supplant this version in 13.0, which is retained for minor version + * compatibility. + * + * Prefetches memory to the specified destination device. \p devPtr is the + * base device pointer of the memory to be prefetched and \p dstDevice is the + * destination device. \p count specifies the number of bytes to copy. \p + * hStream is the stream in which the operation is enqueued. The memory range + * must refer to managed memory allocated via ::cuMemAllocManaged or declared + * via __managed__ variables. + * + * Passing in CU_DEVICE_CPU for \p dstDevice will prefetch the data to host + * memory. If \p dstDevice is a GPU, then the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS must be non-zero. + * Additionally, \p hStream must be associated with a device that has a non-zero + * value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. + * + * The start address and end address of the memory range will be rounded down + * and rounded up respectively to be aligned to CPU page size before the + * prefetch operation is enqueued in the stream. + * + * If no physical memory has been allocated for this region, then this memory + * region will be populated and mapped on the destination device. If there's + * insufficient memory to prefetch the desired region, the Unified Memory driver + * may evict pages from other + * ::cuMemAllocManaged allocations to host memory in order to make room. Device + * memory allocated using ::cuMemAlloc or ::cuArrayCreate will not be evicted. + * + * By default, any mappings to the previous location of the migrated pages are + * removed and mappings for the new location are only setup on \p dstDevice. The + * exact behavior however also depends on the settings applied to this memory + * range via ::cuMemAdvise as described below: + * + * If ::CU_MEM_ADVISE_SET_READ_MOSTLY was set on any subset of this memory + * range, then that subset will create a read-only copy of the pages on \p + * dstDevice. 
+ * + * If ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION was called on any subset of this + * memory range, then the pages will be migrated to \p dstDevice even if \p + * dstDevice is not the preferred location of any pages in the memory range. + * + * If ::CU_MEM_ADVISE_SET_ACCESSED_BY was called on any subset of this memory + * range, then mappings to those pages from all the appropriate processors are + * updated to refer to the new location if establishing such a mapping is + * possible. Otherwise, those mappings are cleared. + * + * Note that this API is not required for functionality and only serves to + * improve performance by allowing the application to migrate data to a suitable + * location before it is accessed. Memory accesses to this range are always + * coherent and are allowed even when the data is actively being migrated. + * + * Note that this function is asynchronous with respect to the host and all work + * on other devices. + * + * \param devPtr - Pointer to be prefetched + * \param count - Size in bytes + * \param dstDevice - Destination device to prefetch to + * \param hStream - Stream to enqueue prefetch operation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemAdvise, ::cuMemPrefetchAsync + * ::cudaMemPrefetchAsync_v2 + */ +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream); + +/** + * \brief Prefetches memory to the specified destination location + * + * Prefetches memory to the specified destination location. \p devPtr is the + * base device pointer of the memory to be prefetched and \p location specifies + * the destination location. \p count specifies the number of bytes to copy. \p + * hStream is the stream in which the operation is enqueued. The memory range + * must refer to managed memory allocated via ::cuMemAllocManaged or declared + * via __managed__ variables. + * + * Specifying ::CU_MEM_LOCATION_TYPE_DEVICE for ::CUmemLocation::type will + * prefetch memory to GPU specified by device ordinal ::CUmemLocation::id which + * must have non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Additionally, \p hStream + * must be associated with a device that has a non-zero value for the device + * attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Specifying + * ::CU_MEM_LOCATION_TYPE_HOST as ::CUmemLocation::type will prefetch data to + * host memory. Applications can request prefetching memory to a specific host + * NUMA node by specifying + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA for ::CUmemLocation::type and a valid host + * NUMA node id in ::CUmemLocation::id Users can also request prefetching memory + * to the host NUMA node closest to the current thread's CPU by specifying + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT for ::CUmemLocation::type. Note when + * ::CUmemLocation::type is either + * ::CU_MEM_LOCATION_TYPE_HOST OR ::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT, + * ::CUmemLocation::id will be ignored. + * + * The start address and end address of the memory range will be rounded down + * and rounded up respectively to be aligned to CPU page size before the + * prefetch operation is enqueued in the stream. + * + * If no physical memory has been allocated for this region, then this memory + * region will be populated and mapped on the destination device. 
If there's + * insufficient memory to prefetch the desired region, the Unified Memory driver + * may evict pages from other + * ::cuMemAllocManaged allocations to host memory in order to make room. Device + * memory allocated using ::cuMemAlloc or ::cuArrayCreate will not be evicted. + * + * By default, any mappings to the previous location of the migrated pages are + * removed and mappings for the new location are only setup on the destination + * location. The exact behavior however also depends on the settings applied to + * this memory range via ::cuMemAdvise as described below: + * + * If ::CU_MEM_ADVISE_SET_READ_MOSTLY was set on any subset of this memory + * range, then that subset will create a read-only copy of the pages on the + * destination location. If however the destination location is a host NUMA + * node, then any pages of that subset that are already in another host NUMA + * node will be transferred to the destination. + * + * If ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION was called on any subset of this + * memory range, then the pages will be migrated to \p location even if \p + * location is not the preferred location of any pages in the memory range. + * + * If ::CU_MEM_ADVISE_SET_ACCESSED_BY was called on any subset of this memory + * range, then mappings to those pages from all the appropriate processors are + * updated to refer to the new location if establishing such a mapping is + * possible. Otherwise, those mappings are cleared. + * + * Note that this API is not required for functionality and only serves to + * improve performance by allowing the application to migrate data to a suitable + * location before it is accessed. Memory accesses to this range are always + * coherent and are allowed even when the data is actively being migrated. + * + * Note that this function is asynchronous with respect to the host and all work + * on other devices. + * + * \param devPtr - Pointer to be prefetched + * \param count - Size in bytes + * \param location - Destination location to prefetch to + * \param flags - Flags for future use, must be zero now. + * \param hStream - Stream to enqueue prefetch operation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemAdvise, ::cuMemPrefetchAsync + * ::cudaMemPrefetchAsync_v2 + */ +CUresult CUDAAPI cuMemPrefetchAsync_v2(CUdeviceptr devPtr, size_t count, + CUmemLocation location, + unsigned int flags, CUstream hStream); + +/** + * \brief Advise about the usage of a given memory range + * + * Note there is a later version of this API, ::cuMemAdvise_v2. It will + * supplant this version in 13.0, which is retained for minor version + * compatibility. + * + * Advise the Unified Memory subsystem about the usage pattern for the memory + * range starting at \p devPtr with a size of \p count bytes. The start address + * and end address of the memory range will be rounded down and rounded up + * respectively to be aligned to CPU page size before the advice is applied. The + * memory range must refer to managed memory allocated via ::cuMemAllocManaged + * or declared via __managed__ variables. The memory range could also refer to + * system-allocated pageable memory provided it represents a valid, + * host-accessible region of memory and all additional constraints imposed by \p + * advice as outlined below are also satisfied.
Specifying an invalid + * system-allocated pageable memory range results in an error being returned. + * + * The \p advice parameter can take the following values: + * - ::CU_MEM_ADVISE_SET_READ_MOSTLY: This implies that the data is mostly going + * to be read from and only occasionally written to. Any read accesses from any + * processor to this region will create a read-only copy of at least the + * accessed pages in that processor's memory. Additionally, if + * ::cuMemPrefetchAsync is called on this region, it will create a read-only + * copy of the data on the destination processor. If any processor writes to + * this region, all copies of the corresponding page will be invalidated except + * for the one where the write occurred. The \p device argument is ignored for + * this advice. Note that for a page to be read-duplicated, the accessing + * processor must either be the CPU or a GPU that has a non-zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Also, if a + * context is created on a device that does not have the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS set, then read-duplication + * will not occur until all such contexts are destroyed. If the memory region + * refers to valid system-allocated pageable memory, then the accessing device + * must have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS for a read-only copy to be + * created on that device. Note however that if the accessing device also has a + * non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then + * setting this advice will not create a read-only copy when that device + * accesses this memory region. + * + * - ::CU_MEM_ADVISE_UNSET_READ_MOSTLY: Undoes the effect of + * ::CU_MEM_ADVISE_SET_READ_MOSTLY and also prevents the Unified Memory driver + * from attempting heuristic read-duplication on the memory range. Any + * read-duplicated copies of the data will be collapsed into a single copy. The + * location for the collapsed copy will be the preferred location if the page + * has a preferred location and one of the read-duplicated copies was resident + * at that location. Otherwise, the location chosen is arbitrary. + * + * - ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION: This advice sets the preferred + * location for the data to be the memory belonging to \p device. Passing in + * CU_DEVICE_CPU for \p device sets the preferred location as host memory. If \p + * device is a GPU, then it must have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Setting the preferred + * location does not cause data to migrate to that location immediately. + * Instead, it guides the migration policy when a fault occurs on that memory + * region. If the data is already in its preferred location and the faulting + * processor can establish a mapping without requiring the data to be migrated, + * then data migration will be avoided. On the other hand, if the data is not in + * its preferred location or if a direct mapping cannot be established, then it + * will be migrated to the processor accessing it. It is important to note that + * setting the preferred location does not prevent data prefetching done using + * ::cuMemPrefetchAsync. Having a preferred location can override the page + * thrash detection and resolution logic in the Unified Memory driver. 
Normally, + * if a page is detected to be constantly thrashing between for example host and + * device memory, the page may eventually be pinned to host memory by the + * Unified Memory driver. But if the preferred location is set as device memory, + * then the page will continue to thrash indefinitely. If + * ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any + * subset of it, then the policies associated with that advice will override the + * policies of this advice, unless read accesses from \p device will not result + * in a read-only copy being created on that device as outlined in description + * for the advice ::CU_MEM_ADVISE_SET_READ_MOSTLY. If the memory region refers + * to valid system-allocated pageable memory, then \p device must have a + * non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. + * + * - ::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION: Undoes the effect of + * ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION and changes the preferred location to + * none. + * + * - ::CU_MEM_ADVISE_SET_ACCESSED_BY: This advice implies that the data will be + * accessed by \p device. Passing in ::CU_DEVICE_CPU for \p device will set the + * advice for the CPU. If \p device is a GPU, then the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS must be non-zero. This advice + * does not cause data migration and has no impact on the location of the data + * per se. Instead, it causes the data to always be mapped in the specified + * processor's page tables, as long as the location of the data permits a + * mapping to be established. If the data gets migrated for any reason, the + * mappings are updated accordingly. This advice is recommended in scenarios + * where data locality is not important, but avoiding faults is. Consider for + * example a system containing multiple GPUs with peer-to-peer access enabled, + * where the data located on one GPU is occasionally accessed by peer GPUs. In + * such scenarios, migrating data over to the other GPUs is not as important + * because the accesses are infrequent and the overhead of migration may be too + * high. But preventing faults can still help improve performance, and so having + * a mapping set up in advance is useful. Note that on CPU access of this data, + * the data may be migrated to host memory because the CPU typically cannot + * access device memory directly. Any GPU that had the + * ::CU_MEM_ADVISE_SET_ACCESSED_BY flag set for this data will now have its + * mapping updated to point to the page in host memory. If + * ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any + * subset of it, then the policies associated with that advice will override the + * policies of this advice. Additionally, if the preferred location of this + * memory region or any subset of it is also \p device, then the policies + * associated with ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION will override the + * policies of this advice. If the memory region refers to valid + * system-allocated pageable memory, then \p device must have a non-zero value + * for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. + * Additionally, if \p device has a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then this + * call has no effect. + * + * - ::CU_MEM_ADVISE_UNSET_ACCESSED_BY: Undoes the effect of + * ::CU_MEM_ADVISE_SET_ACCESSED_BY. 
Any mappings to the data from \p device may + * be removed at any time causing accesses to result in non-fatal page faults. + * If the memory region refers to valid system-allocated pageable memory, then + * \p device must have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if \p device has + * a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then this + * call has no effect. + * + * \param devPtr - Pointer to memory to set the advice for + * \param count - Size in bytes of the memory range + * \param advice - Advice to be applied for the specified memory range + * \param device - Device to apply the advice for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemPrefetchAsync, ::cuMemAdvise_v2 + * ::cudaMemAdvise + */ +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUdevice device); + +/** + * \brief Advise about the usage of a given memory range + * + * Advise the Unified Memory subsystem about the usage pattern for the memory + * range starting at \p devPtr with a size of \p count bytes. The start address + * and end address of the memory range will be rounded down and rounded up + * respectively to be aligned to CPU page size before the advice is applied. The + * memory range must refer to managed memory allocated via ::cuMemAllocManaged + * or declared via __managed__ variables. The memory range could also refer to + * system-allocated pageable memory provided it represents a valid, + * host-accessible region of memory and all additional constraints imposed by \p + * advice as outlined below are also satisfied. Specifying an invalid + * system-allocated pageable memory range results in an error being returned. + * + * The \p advice parameter can take the following values: + * - ::CU_MEM_ADVISE_SET_READ_MOSTLY: This implies that the data is mostly going + * to be read from and only occasionally written to. Any read accesses from any + * processor to this region will create a read-only copy of at least the + * accessed pages in that processor's memory. Additionally, if + * ::cuMemPrefetchAsync or ::cuMemPrefetchAsync_v2 is called on this region, it + * will create a read-only copy of the data on the destination processor. If the + * target location for ::cuMemPrefetchAsync_v2 is a host NUMA node and a + * read-only copy already exists on another host NUMA node, that copy will be + * migrated to the targeted host NUMA node. If any processor writes to this + * region, all copies of the corresponding page will be invalidated except for + * the one where the write occurred. If the writing processor is the CPU and the + * preferred location of the page is a host NUMA node, then the page will also + * be migrated to that host NUMA node. The \p location argument is ignored for + * this advice. Note that for a page to be read-duplicated, the accessing + * processor must either be the CPU or a GPU that has a non-zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Also, if a + * context is created on a device that does not have the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS set, then read-duplication + * will not occur until all such contexts are destroyed. 
If the memory region + * refers to valid system-allocated pageable memory, then the accessing device + * must have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS for a read-only copy to be + * created on that device. Note however that if the accessing device also has a + * non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then + * setting this advice will not create a read-only copy when that device + * accesses this memory region. + * + * - ::CU_MEM_ADVISE_UNSET_READ_MOSTLY: Undoes the effect of + * ::CU_MEM_ADVISE_SET_READ_MOSTLY and also prevents the Unified Memory driver + * from attempting heuristic read-duplication on the memory range. Any + * read-duplicated copies of the data will be collapsed into a single copy. The + * location for the collapsed copy will be the preferred location if the page + * has a preferred location and one of the read-duplicated copies was resident + * at that location. Otherwise, the location chosen is arbitrary. Note: The \p + * location argument is ignored for this advice. + * + * - ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION: This advice sets the preferred + * location for the data to be the memory belonging to \p location. When + * ::CUmemLocation::type is ::CU_MEM_LOCATION_TYPE_HOST, + * ::CUmemLocation::id is ignored and the preferred location is set to be host + * memory. To set the preferred location to a specific host NUMA node, + * applications must set ::CUmemLocation::type to + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA and + * ::CUmemLocation::id must specify the NUMA ID of the host NUMA node. If + * ::CUmemLocation::type is set to ::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT, + * ::CUmemLocation::id will be ignored and the host NUMA node closest to the + * calling thread's CPU will be used as the preferred location. If + * ::CUmemLocation::type is ::CU_MEM_LOCATION_TYPE_DEVICE, then + * ::CUmemLocation::id must be a valid device ordinal and the device must have a + * non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Setting the preferred + * location does not cause data to migrate to that location immediately. + * Instead, it guides the migration policy when a fault occurs on that memory + * region. If the data is already in its preferred location and the faulting + * processor can establish a mapping without requiring the data to be migrated, + * then data migration will be avoided. On the other hand, if the data is not in + * its preferred location or if a direct mapping cannot be established, then it + * will be migrated to the processor accessing it. It is important to note that + * setting the preferred location does not prevent data prefetching done using + * ::cuMemPrefetchAsync. Having a preferred location can override the page + * thrash detection and resolution logic in the Unified Memory driver. Normally, + * if a page is detected to be constantly thrashing between for example host and + * device memory, the page may eventually be pinned to host memory by the + * Unified Memory driver. But if the preferred location is set as device memory, + * then the page will continue to thrash indefinitely.
If + * ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any + * subset of it, then the policies associated with that advice will override the + * policies of this advice, unless read accesses from \p location will not + * result in a read-only copy being created on that processor as outlined in + * description for the advice ::CU_MEM_ADVISE_SET_READ_MOSTLY. If the memory + * region refers to valid system-allocated pageable memory, and + * ::CUmemLocation::type is CU_MEM_LOCATION_TYPE_DEVICE then ::CUmemLocation::id + * must be a valid device that has a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. + * + * - ::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION: Undoes the effect of + * ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION and changes the preferred location to + * none. The \p location argument is ignored for this advice. + * + * - ::CU_MEM_ADVISE_SET_ACCESSED_BY: This advice implies that the data will be + * accessed by processor \p location. The ::CUmemLocation::type must be either + * ::CU_MEM_LOCATION_TYPE_DEVICE with ::CUmemLocation::id representing a valid + * device ordinal or ::CU_MEM_LOCATION_TYPE_HOST and ::CUmemLocation::id will be + * ignored. All other location types are invalid. If ::CUmemLocation::id is a + * GPU, then the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS must be non-zero. This advice + * does not cause data migration and has no impact on the location of the data + * per se. Instead, it causes the data to always be mapped in the specified + * processor's page tables, as long as the location of the data permits a + * mapping to be established. If the data gets migrated for any reason, the + * mappings are updated accordingly. This advice is recommended in scenarios + * where data locality is not important, but avoiding faults is. Consider for + * example a system containing multiple GPUs with peer-to-peer access enabled, + * where the data located on one GPU is occasionally accessed by peer GPUs. In + * such scenarios, migrating data over to the other GPUs is not as important + * because the accesses are infrequent and the overhead of migration may be too + * high. But preventing faults can still help improve performance, and so having + * a mapping set up in advance is useful. Note that on CPU access of this data, + * the data may be migrated to host memory because the CPU typically cannot + * access device memory directly. Any GPU that had the + * ::CU_MEM_ADVISE_SET_ACCESSED_BY flag set for this data will now have its + * mapping updated to point to the page in host memory. If + * ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any + * subset of it, then the policies associated with that advice will override the + * policies of this advice. Additionally, if the preferred location of this + * memory region or any subset of it is also \p location, then the policies + * associated with ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION will override the + * policies of this advice. If the memory region refers to valid + * system-allocated pageable memory, and ::CUmemLocation::type is + * ::CU_MEM_LOCATION_TYPE_DEVICE then device in ::CUmemLocation::id must have a + * non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if + * ::CUmemLocation::id has a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then this + * call has no effect. 
+ * + * - ::CU_MEM_ADVISE_UNSET_ACCESSED_BY: Undoes the effect of + * ::CU_MEM_ADVISE_SET_ACCESSED_BY. Any mappings to the data from \p location + * may be removed at any time causing accesses to result in non-fatal page + * faults. If the memory region refers to valid system-allocated pageable + * memory, and ::CUmemLocation::type is ::CU_MEM_LOCATION_TYPE_DEVICE then + * device in ::CUmemLocation::id must have a non-zero value for the device + * attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if + * ::CUmemLocation::id has a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then this + * call has no effect. + * + * \param devPtr - Pointer to memory to set the advice for + * \param count - Size in bytes of the memory range + * \param advice - Advice to be applied for the specified memory range + * \param location - location to apply the advice for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemPrefetchAsync, ::cuMemAdvise + * ::cudaMemAdvise + */ +CUresult CUDAAPI cuMemAdvise_v2(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUmemLocation location); + +/** + * \brief Query an attribute of a given memory range + * + * Query an attribute about the memory range starting at \p devPtr with a size + * of \p count bytes. The memory range must refer to managed memory allocated + * via ::cuMemAllocManaged or declared via + * __managed__ variables. + * + * The \p attribute parameter can take the following values: + * - ::CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY: If this attribute is specified, \p + * data will be interpreted as a 32-bit integer, and \p dataSize must be 4. The + * result returned will be 1 if all pages in the given memory range have + * read-duplication enabled, or 0 otherwise. + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION: If this attribute is + * specified, \p data will be interpreted as a 32-bit integer, and \p dataSize + * must be 4. The result returned will be a GPU device id if all pages in the + * memory range have that GPU as their preferred location, or it will be + * CU_DEVICE_CPU if all pages in the memory range have the CPU as their + * preferred location, or it will be CU_DEVICE_INVALID if either all the pages + * don't have the same preferred location or some of the pages don't have a + * preferred location at all. Note that the actual location of the pages in the + * memory range at the time of the query may be different from the preferred + * location. + * - ::CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY: If this attribute is specified, \p + * data will be interpreted as an array of 32-bit integers, and \p dataSize must + * be a non-zero multiple of 4. The result returned will be a list of device ids + * that had ::CU_MEM_ADVISE_SET_ACCESSED_BY set for that entire memory range. If + * any device does not have that advice set for the entire memory range, that + * device will not be included. If \p data is larger than the number of devices + * that have that advice set for that memory range, CU_DEVICE_INVALID will be + * returned in all the extra space provided. For ex., if \p dataSize is 12 (i.e. + * \p data has 3 elements) and only device 0 has the advice set, then the result + * returned will be { 0, CU_DEVICE_INVALID, CU_DEVICE_INVALID }. 
If \p data is + * smaller than the number of devices that have that advice set, then only as + * many devices will be returned as can fit in the array. There is no guarantee + * on which specific devices will be returned, however. + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION: If this attribute is + * specified, \p data will be interpreted as a 32-bit integer, and \p dataSize + * must be 4. The result returned will be the last location to which all pages + * in the memory range were prefetched explicitly via ::cuMemPrefetchAsync. This + * will either be a GPU id or CU_DEVICE_CPU depending on whether the last + * location for prefetch was a GPU or the CPU respectively. If any page in the + * memory range was never explicitly prefetched or if all pages were not + * prefetched to the same location, CU_DEVICE_INVALID will be returned. Note + * that this simply returns the last location that the application requested to + * prefetch the memory range to. It gives no indication as to whether the + * prefetch operation to that location has completed or even begun. + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_TYPE: If this attribute is + * specified, \p data will be interpreted as a ::CUmemLocationType, and \p + * dataSize must be sizeof(CUmemLocationType). The ::CUmemLocationType returned + * will be + * ::CU_MEM_LOCATION_TYPE_DEVICE if all pages in the memory range have the same + * GPU as their preferred location, or ::CUmemLocationType will be + * ::CU_MEM_LOCATION_TYPE_HOST if all pages in the memory range have the CPU as + * their preferred location, or it will be ::CU_MEM_LOCATION_TYPE_HOST_NUMA if + * all the pages in the memory range have the same host NUMA node ID as their + * preferred location or it will be ::CU_MEM_LOCATION_TYPE_INVALID if either all + * the pages don't have the same preferred location or some of the pages don't + * have a preferred location at all. Note that the actual location type of the + * pages in the memory range at the time of the query may be different from the + * preferred location type. + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_ID: If this attribute is + * specified, \p data will be interpreted as a 32-bit integer, and \p dataSize + * must be 4. If the ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_TYPE query for + * the same address range returns ::CU_MEM_LOCATION_TYPE_DEVICE, it will be a + * valid device ordinal or if it returns ::CU_MEM_LOCATION_TYPE_HOST_NUMA, it + * will be a valid host NUMA node ID or if it returns any other location type, + * the id should be ignored. + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_TYPE: If this attribute is + * specified, \p data will be interpreted as a ::CUmemLocationType, and \p + * dataSize must be sizeof(CUmemLocationType). The result returned will be the + * last location to which all pages in the memory range were prefetched + * explicitly via ::cuMemPrefetchAsync. The ::CUmemLocationType returned will be + * ::CU_MEM_LOCATION_TYPE_DEVICE if the last prefetch location was a GPU or + * ::CU_MEM_LOCATION_TYPE_HOST if it was the CPU or + * ::CU_MEM_LOCATION_TYPE_HOST_NUMA if the last prefetch location was a specific + * host NUMA node. If any page in the memory range was never explicitly + * prefetched or if all pages were not prefetched to the same location, + * ::CUmemLocationType will be ::CU_MEM_LOCATION_TYPE_INVALID. Note that this + * simply returns the last location type that the application requested to + * prefetch the memory range to. 
It gives no indication as to whether the + * prefetch operation to that location has completed or even begun. + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_ID: If this attribute is + * specified, \p data will be interpreted as a 32-bit integer, and \p dataSize + * must be 4. If the ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_TYPE query + * for the same address range returns ::CU_MEM_LOCATION_TYPE_DEVICE, it will be + * a valid device ordinal or if it returns ::CU_MEM_LOCATION_TYPE_HOST_NUMA, it + * will be a valid host NUMA node ID or if it returns any other location type, + * the id should be ignored. + * + * \param data - A pointer to a memory location where the result + * of the attribute query will be written to. + * \param dataSize - The size of \p data + * \param attribute - The attribute to query + * \param devPtr - Start of the range to query + * \param count - Size of the range to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemRangeGetAttributes, ::cuMemPrefetchAsync, + * ::cuMemAdvise, + * ::cudaMemRangeGetAttribute + */ +CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, + CUmem_range_attribute attribute, + CUdeviceptr devPtr, size_t count); + +/** + * \brief Query attributes of a given memory range. + * + * Query attributes of the memory range starting at \p devPtr with a size of \p + * count bytes. The memory range must refer to managed memory allocated via + * ::cuMemAllocManaged or declared via + * __managed__ variables. The \p attributes array will be interpreted to have \p + * numAttributes entries. The \p dataSizes array will also be interpreted to + * have \p numAttributes entries. The results of the query will be stored in \p + * data. + * + * The list of supported attributes is given below. Please refer to + * ::cuMemRangeGetAttribute for attribute descriptions and restrictions. + * + * - ::CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION + * - ::CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_TYPE + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_ID + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_TYPE + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION_ID + * + * \param data - A two-dimensional array containing pointers to memory + * locations where the result of each attribute query + * will be written to.
\param dataSizes - Array containing the sizes of each + * result \param attributes - An array of attributes to query (numAttributes + * and the number of attributes in this array should match) \param numAttributes + * - Number of attributes to query \param devPtr - Start of the range to + * query \param count - Size of the range to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa ::cuMemRangeGetAttribute, ::cuMemAdvise, + * ::cuMemPrefetchAsync, + * ::cudaMemRangeGetAttributes + */ +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, + CUmem_range_attribute *attributes, + size_t numAttributes, + CUdeviceptr devPtr, size_t count); + +/** + * \brief Set attributes on a previously allocated memory region + * + * The supported attributes are: + * + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS: + * + * A boolean attribute that can either be set (1) or unset (0). When set, + * the region of memory that \p ptr points to is guaranteed to always + * synchronize memory operations that are synchronous. If there are some + * previously initiated synchronous memory operations that are pending when this + * attribute is set, the function does not return until those memory operations + * are complete. See further documentation in the section titled "API + * synchronization behavior" to learn more about cases when synchronous memory + * operations can exhibit asynchronous behavior. \p value will be considered as + * a pointer to an unsigned integer to which this attribute is to be set. + * + * \param value - Pointer to memory containing the value to be set + * \param attribute - Pointer attribute to set + * \param ptr - Pointer to a memory region allocated using CUDA memory + * allocation APIs + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa ::cuPointerGetAttribute, + * ::cuPointerGetAttributes, + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuMemAllocHost, + * ::cuMemFreeHost, + * ::cuMemHostAlloc, + * ::cuMemHostRegister, + * ::cuMemHostUnregister + */ +CUresult CUDAAPI cuPointerSetAttribute(const void *value, + CUpointer_attribute attribute, + CUdeviceptr ptr); + +/** + * \brief Returns information about a pointer. + * + * The supported attributes are (refer to ::cuPointerGetAttribute for attribute + * descriptions and restrictions): + * + * - ::CU_POINTER_ATTRIBUTE_CONTEXT + * - ::CU_POINTER_ATTRIBUTE_MEMORY_TYPE + * - ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER + * - ::CU_POINTER_ATTRIBUTE_HOST_POINTER + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS + * - ::CU_POINTER_ATTRIBUTE_BUFFER_ID + * - ::CU_POINTER_ATTRIBUTE_IS_MANAGED + * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE + * - ::CU_POINTER_ATTRIBUTE_MAPPED + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE + * + * \param numAttributes - Number of attributes to query + * \param attributes - An array of attributes to query + * (numAttributes and the number of attributes in this + * array should match) \param data - A two-dimensional array containing + * pointers to memory locations where the result of each attribute query will be + * written to. 
\param ptr - Pointer to query + * + * Unlike ::cuPointerGetAttribute, this function will not return an error when + * the \p ptr encountered is not a valid CUDA pointer. Instead, the attributes + * are assigned default NULL values and CUDA_SUCCESS is returned. + * + * If \p ptr was not allocated by, mapped by, or registered with a ::CUcontext + * which uses UVA (Unified Virtual Addressing), ::CUDA_ERROR_INVALID_CONTEXT is + * returned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuPointerGetAttribute, + * ::cuPointerSetAttribute, + * ::cudaPointerGetAttributes + */ +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, + CUpointer_attribute *attributes, + void **data, CUdeviceptr ptr); + +/** @} */ /* END CUDA_UNIFIED */ + +/** + * \defgroup CUDA_STREAM Stream Management + * + * ___MANBRIEF___ stream management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Create a stream + * + * Creates a stream and returns a handle in \p phStream. The \p Flags argument + * determines behaviors of the stream. + * + * Valid values for \p Flags are: + * - ::CU_STREAM_DEFAULT: Default stream creation flag. + * - ::CU_STREAM_NON_BLOCKING: Specifies that work running in the created + * stream may run concurrently with work in stream 0 (the NULL stream), and + * that the created stream should perform no implicit synchronization with + * stream 0. + * + * \param phStream - Returned newly created stream + * \param Flags - Parameters for stream creation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreate, + * ::cudaStreamCreateWithFlags + */ +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags); + +/** + * \brief Create a stream with the given priority + * + * Creates a stream with the specified priority and returns a handle in \p + * phStream. This affects the scheduling priority of work in the stream. + * Priorities provide a hint to preferentially run work with higher priority + * when possible, but do not preempt already-running work or provide any other + * functional guarantee on execution order. + * + * \p priority follows a convention where lower numbers represent higher + * priorities. '0' represents default priority. The range of meaningful + * numerical priorities can be queried using ::cuCtxGetStreamPriorityRange. If + * the specified priority is outside the numerical range returned by + * ::cuCtxGetStreamPriorityRange, it will automatically be clamped to the lowest + * or the highest number in the range. + * + * \param phStream - Returned newly created stream + * \param flags - Flags for stream creation. See ::cuStreamCreate for a + * list of valid flags \param priority - Stream priority. Lower numbers + * represent higher priorities. 
See ::cuCtxGetStreamPriorityRange for more + * information about meaningful stream priorities that can be passed. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \note Stream priorities are supported only on GPUs + * with compute capability 3.5 or higher. + * + * \note In the current implementation, only compute kernels launched in + * priority streams are affected by the stream's priority. Stream priorities + * have no effect on host-to-device and device-to-host memory operations. + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cuCtxGetStreamPriorityRange, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreateWithPriority + */ +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, + unsigned int flags, int priority); + +/** + * \brief Query the priority of a given stream + * + * Query the priority of a stream created using ::cuStreamCreate or + * ::cuStreamCreateWithPriority and return the priority in \p priority. Note + * that if the stream was created with a priority outside the numerical range + * returned by ::cuCtxGetStreamPriorityRange, this function returns the clamped + * priority. See ::cuStreamCreateWithPriority for details about priority + * clamping. + * + * \param hStream - Handle to the stream to be queried + * \param priority - Pointer to a signed integer in which the stream's + * priority is returned \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamCreateWithPriority, + * ::cuCtxGetStreamPriorityRange, + * ::cuStreamGetFlags, + * ::cudaStreamGetPriority + */ +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); + +/** + * \brief Query the flags of a given stream + * + * Query the flags of a stream created using ::cuStreamCreate or + * ::cuStreamCreateWithPriority and return the flags in \p flags. + * + * \param hStream - Handle to the stream to be queried + * \param flags - Pointer to an unsigned integer in which the stream's + * flags are returned The value returned in \p flags is a logical 'OR' of all + * flags that were used while creating this stream. See ::cuStreamCreate for the + * list of valid flags \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cudaStreamGetFlags + */ +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); + +/** + * \brief Returns the unique Id associated with the stream handle supplied + * + * Returns in \p streamId the unique Id which is associated with the given + * stream handle. The Id is unique for the life of the program. + * + * The stream handle \p hStream can refer to any of the following: + *
    + *
+ * - a stream created via any of the CUDA driver APIs such as ::cuStreamCreate
+ *   and ::cuStreamCreateWithPriority, or their runtime API equivalents such as
+ *   ::cudaStreamCreate, ::cudaStreamCreateWithFlags and
+ *   ::cudaStreamCreateWithPriority. Passing an invalid handle will result in
+ *   undefined behavior.
+ * - any of the special streams such as the NULL stream, ::CU_STREAM_LEGACY and
+ *   ::CU_STREAM_PER_THREAD. The runtime API equivalents of these are also
+ *   accepted, which are NULL, ::cudaStreamLegacy and ::cudaStreamPerThread
+ *   respectively.
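+ *
+ * A minimal illustrative sketch (assuming the driver API has already been
+ * initialized and a context is current; error checking omitted):
+ *
+ * \code
+ * CUstream stream;
+ * unsigned long long id;
+ * cuStreamCreate(&stream, CU_STREAM_DEFAULT);  // create a stream
+ * cuStreamGetId(stream, &id);                  // id is unique for the life of the program
+ * cuStreamDestroy(stream);
+ * \endcode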
+ * + * \param hStream - Handle to the stream to be queried + * \param streamId - Pointer to store the Id of the stream + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cudaStreamGetId + */ +CUresult CUDAAPI cuStreamGetId(CUstream hStream, unsigned long long *streamId); + +/** + * \brief Query the context associated with a stream + * + * Returns the CUDA context that the stream is associated with. + * + * The stream handle \p hStream can refer to any of the following: + *
    + *
+ * - a stream created via any of the CUDA driver APIs such as ::cuStreamCreate
+ *   and ::cuStreamCreateWithPriority, or their runtime API equivalents such as
+ *   ::cudaStreamCreate, ::cudaStreamCreateWithFlags and
+ *   ::cudaStreamCreateWithPriority. The returned context is the context that
+ *   was active in the calling thread when the stream was created. Passing an
+ *   invalid handle will result in undefined behavior.
+ * - any of the special streams such as the NULL stream, ::CU_STREAM_LEGACY and
+ *   ::CU_STREAM_PER_THREAD. The runtime API equivalents of these are also
+ *   accepted, which are NULL, ::cudaStreamLegacy and ::cudaStreamPerThread
+ *   respectively. Specifying any of the special handles will return the context
+ *   current to the calling thread. If no context is current to the calling
+ *   thread, ::CUDA_ERROR_INVALID_CONTEXT is returned.
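+ *
+ * A minimal illustrative sketch (assuming a context is current in the calling
+ * thread when the stream is created; error checking omitted):
+ *
+ * \code
+ * CUstream stream;
+ * CUcontext ctx;
+ * cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
+ * cuStreamGetCtx(stream, &ctx);  // ctx is the context that was current at creation time
+ * cuStreamDestroy(stream);
+ * \endcode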
+ * + * \param hStream - Handle to the stream to be queried + * \param pctx - Returned context associated with the stream + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreate, + * ::cudaStreamCreateWithFlags + */ +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); + +/** + * \brief Make a compute stream wait on an event + * + * Makes all future work submitted to \p hStream wait for all work captured in + * \p hEvent. See ::cuEventRecord() for details on what is captured by an + * event. The synchronization will be performed efficiently on the device when + * applicable. \p hEvent may be from a different context or device than \p + * hStream. + * + * flags include: + * - ::CU_EVENT_WAIT_DEFAULT: Default event creation flag. + * - ::CU_EVENT_WAIT_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. + * + * \param hStream - Stream to wait + * \param hEvent - Event to wait on (may not be NULL) + * \param Flags - See ::CUevent_capture_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuEventRecord, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cuStreamDestroy, + * ::cudaStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags); + +/** + * \brief Add a callback to a compute stream + * + * \note This function is slated for eventual deprecation and removal. If + * you do not require the callback to execute in case of a device error, + * consider using ::cuLaunchHostFunc. Additionally, this function is not + * supported with ::cuStreamBeginCapture and ::cuStreamEndCapture, unlike + * ::cuLaunchHostFunc. + * + * Adds a callback to be called on the host after all currently enqueued + * items in the stream have completed. For each + * cuStreamAddCallback call, the callback will be executed exactly once. + * The callback will block later work in the stream until it is finished. + * + * The callback may be passed ::CUDA_SUCCESS or an error code. In the event + * of a device error, all subsequently executed callbacks will receive an + * appropriate ::CUresult. + * + * Callbacks must not make any CUDA API calls. Attempting to use a CUDA API + * will result in ::CUDA_ERROR_NOT_PERMITTED. Callbacks must not perform any + * synchronization that may depend on outstanding device work or other callbacks + * that are not mandated to run earlier. Callbacks without a mandated order + * (in independent streams) execute in undefined order and may be serialized. + * + * For the purposes of Unified Memory, callback execution makes a number of + * guarantees: + *
    + *
+ * - The callback stream is considered idle for the duration of the
+ *   callback. Thus, for example, a callback may always use memory attached
+ *   to the callback stream.
+ * - The start of execution of a callback has the same effect as
+ *   synchronizing an event recorded in the same stream immediately prior to
+ *   the callback. It thus synchronizes streams which have been "joined"
+ *   prior to the callback.
+ * - Adding device work to any stream does not have the effect of making
+ *   the stream active until all preceding host functions and stream
+ *   callbacks have executed. Thus, for example, a callback might use global
+ *   attached memory even if work has been added to another stream, if the
+ *   work has been ordered behind the callback with an event.
+ * - Completion of a callback does not cause a stream to become
+ *   active except as described above. The callback stream will remain idle
+ *   if no device work follows the callback, and will remain idle across
+ *   consecutive callbacks without device work in between. Thus, for example,
+ *   stream synchronization can be done by signaling from a callback at the
+ *   end of the stream.
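+ *
+ * As an illustration only (this sketch is not part of the original header;
+ * \c myCallback and \c hStream are placeholder names and error checking is
+ * omitted), a typical use might look like: \code
    void CUDA_CB myCallback(CUstream stream, CUresult status, void *userData) {
        // Host-side work only; CUDA API calls are not permitted here.
    }
    ...
    cuStreamAddCallback(hStream, myCallback, NULL, 0);
+ * \endcode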
+ * + * \param hStream - Stream to add callback to + * \param callback - The function to call once preceding stream operations are + * complete \param userData - User specified data to be passed to the callback + * function \param flags - Reserved for future use, must be 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cuStreamAttachMemAsync, + * ::cuLaunchHostFunc, + * ::cudaStreamAddCallback + */ +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags); + +/** + * \brief Begins graph capture on a stream + * + * Begin graph capture on \p hStream. When a stream is in capture mode, all + * operations pushed into the stream will not be executed, but will instead be + * captured into a graph, which will be returned via ::cuStreamEndCapture. + * Capture may not be initiated if \p stream is CU_STREAM_LEGACY. Capture must + * be ended on the same stream in which it was initiated, and it may only be + * initiated if the stream is not already in capture mode. The capture mode may + * be queried via ::cuStreamIsCapturing. A unique id representing the capture + * sequence may be queried via ::cuStreamGetCaptureInfo. + * + * If \p mode is not ::CU_STREAM_CAPTURE_MODE_RELAXED, ::cuStreamEndCapture must + * be called on this stream from the same thread. + * + * \param hStream - Stream in which to initiate capture + * \param mode - Controls the interaction of this capture sequence with other + * API calls that are potentially unsafe. For more details see + * ::cuThreadExchangeStreamCaptureMode. + * + * \note Kernels captured using this API must not use texture and surface + * references. Reading or writing through any texture or surface reference is + * undefined behavior. This restriction does not apply to texture and surface + * objects. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamIsCapturing, + * ::cuStreamEndCapture, + * ::cuThreadExchangeStreamCaptureMode + */ +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, + CUstreamCaptureMode mode); + +/** + * \brief Begins graph capture on a stream to an existing graph + * + * Begin graph capture on \p hStream, placing new nodes into an existing graph. + * When a stream is in capture mode, all operations pushed into the stream will + * not be executed, but will instead be captured into \p hGraph. The graph will + * not be instantiable until the user calls + * ::cuStreamEndCapture. + * + * Capture may not be initiated if \p stream is CU_STREAM_LEGACY. Capture must + * be ended on the same stream in which it was initiated, and it may only be + * initiated if the stream is not already in capture mode. The capture mode may + * be queried via ::cuStreamIsCapturing. A unique id representing the capture + * sequence may be queried via ::cuStreamGetCaptureInfo. + * + * If \p mode is not ::CU_STREAM_CAPTURE_MODE_RELAXED, ::cuStreamEndCapture must + * be called on this stream from the same thread. + * + * \param hStream - Stream in which to initiate capture. + * \param hGraph - Graph to capture into. 
+ * \param dependencies - Dependencies of the first node captured in the + * stream. Can be NULL if numDependencies is 0. \param dependencyData - + * Optional array of data associated with each dependency. \param + * numDependencies - Number of dependencies. \param mode - Controls + * the interaction of this capture sequence with other API calls that are + * potentially unsafe. For more details see + * ::cuThreadExchangeStreamCaptureMode. + * + * \note Kernels captured using this API must not use texture and surface + * references. Reading or writing through any texture or surface reference is + * undefined behavior. This restriction does not apply to texture and surface + * objects. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamBeginCapture, + * ::cuStreamCreate, + * ::cuStreamIsCapturing, + * ::cuStreamEndCapture, + * ::cuThreadExchangeStreamCaptureMode, + * ::cuGraphAddNode, + */ +CUresult CUDAAPI cuStreamBeginCaptureToGraph( + CUstream hStream, CUgraph hGraph, const CUgraphNode *dependencies, + const CUgraphEdgeData *dependencyData, size_t numDependencies, + CUstreamCaptureMode mode); + +/** + * \brief Swaps the stream capture interaction mode for a thread + * + * Sets the calling thread's stream capture interaction mode to the value + contained + * in \p *mode, and overwrites \p *mode with the previous mode for the thread. + To + * facilitate deterministic behavior across function or module boundaries, + callers + * are encouraged to use this API in a push-pop fashion: \code + CUstreamCaptureMode mode = desiredMode; + cuThreadExchangeStreamCaptureMode(&mode); + ... + cuThreadExchangeStreamCaptureMode(&mode); // restore previous mode + * \endcode + * + * During stream capture (see ::cuStreamBeginCapture), some actions, such as a + call + * to ::cudaMalloc, may be unsafe. In the case of ::cudaMalloc, the operation is + * not enqueued asynchronously to a stream, and is not observed by stream + capture. + * Therefore, if the sequence of operations captured via ::cuStreamBeginCapture + * depended on the allocation being replayed whenever the graph is launched, the + * captured graph would be invalid. + * + * Therefore, stream capture places restrictions on API calls that can be made + within + * or concurrently to a ::cuStreamBeginCapture-::cuStreamEndCapture sequence. + This + * behavior can be controlled via this API and flags to ::cuStreamBeginCapture. + * + * A thread's mode is one of the following: + * - \p CU_STREAM_CAPTURE_MODE_GLOBAL: This is the default mode. If the local + thread has + * an ongoing capture sequence that was not initiated with + * \p CU_STREAM_CAPTURE_MODE_RELAXED at \p cuStreamBeginCapture, or if any + other thread + * has a concurrent capture sequence initiated with \p + CU_STREAM_CAPTURE_MODE_GLOBAL, + * this thread is prohibited from potentially unsafe API calls. + * - \p CU_STREAM_CAPTURE_MODE_THREAD_LOCAL: If the local thread has an ongoing + capture + * sequence not initiated with \p CU_STREAM_CAPTURE_MODE_RELAXED, it is + prohibited + * from potentially unsafe API calls. Concurrent capture sequences in other + threads + * are ignored. + * - \p CU_STREAM_CAPTURE_MODE_RELAXED: The local thread is not prohibited from + potentially + * unsafe API calls. 
Note that the thread is still prohibited from API calls + which + * necessarily conflict with stream capture, for example, attempting + ::cuEventQuery + * on an event that was last recorded inside a capture sequence. + * + * \param mode - Pointer to mode value to swap with the current mode + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamBeginCapture + */ +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode); + +/** + * \brief Ends capture on a stream, returning the captured graph + * + * End capture on \p hStream, returning the captured graph via \p phGraph. + * Capture must have been initiated on \p hStream via a call to + * ::cuStreamBeginCapture. If capture was invalidated, due to a violation of the + * rules of stream capture, then a NULL graph will be returned. + * + * If the \p mode argument to ::cuStreamBeginCapture was not + * ::CU_STREAM_CAPTURE_MODE_RELAXED, this call must be from the same thread as + * ::cuStreamBeginCapture. + * + * \param hStream - Stream to query + * \param phGraph - The captured graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing, + * ::cuGraphDestroy + */ +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); + +/** + * \brief Returns a stream's capture status + * + * Return the capture status of \p hStream via \p captureStatus. After a + * successful call, \p *captureStatus will contain one of the following: + * - ::CU_STREAM_CAPTURE_STATUS_NONE: The stream is not capturing. + * - ::CU_STREAM_CAPTURE_STATUS_ACTIVE: The stream is capturing. + * - ::CU_STREAM_CAPTURE_STATUS_INVALIDATED: The stream was capturing but an + * error has invalidated the capture sequence. The capture sequence must be + * terminated with ::cuStreamEndCapture on the stream where it was initiated in + * order to continue using \p hStream. + * + * Note that, if this is called on ::CU_STREAM_LEGACY (the "null stream") while + * a blocking stream in the same context is capturing, it will return + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT and \p *captureStatus is unspecified + * after the call. The blocking stream capture is not invalidated. + * + * When a blocking stream is capturing, the legacy stream is in an + * unusable state until the blocking stream capture is terminated. The legacy + * stream is not supported for stream capture, but attempted use would have an + * implicit dependency on the capturing stream(s). + * + * \param hStream - Stream to query + * \param captureStatus - Returns the stream's capture status + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamBeginCapture, + * ::cuStreamEndCapture + */ +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, + CUstreamCaptureStatus *captureStatus); + +/** + * \brief Query a stream's capture state + * + * Query stream state related to stream capture. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not + * created with ::CU_STREAM_NON_BLOCKING is capturing, returns + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. 
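+ *
+ * A minimal capture-and-query sketch (illustrative only, not part of the
+ * original header; \c hStream is a placeholder and error checking is
+ * omitted): \code
    CUstreamCaptureStatus status;
    cuuint64_t id;
    CUgraph graph;
    cuStreamBeginCapture(hStream, CU_STREAM_CAPTURE_MODE_GLOBAL);
    // ... enqueue asynchronous work into hStream ...
    cuStreamGetCaptureInfo(hStream, &status, &id, NULL, NULL, NULL);
    // status == CU_STREAM_CAPTURE_STATUS_ACTIVE while capture is in progress
    cuStreamEndCapture(hStream, &graph);
+ * \endcode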
+ * + * Valid data (other than capture status) is returned only if both of the + * following are true: + * - the call returns CUDA_SUCCESS + * - the returned capture status is ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * \param hStream - The stream to query + * \param captureStatus_out - Location to return the capture status of the + * stream; required \param id_out - Optional location to return an id for the + * capture sequence, which is unique over the lifetime of the process \param + * graph_out - Optional location to return the graph being captured into. All + * operations other than destroy and node removal are permitted on the + * graph while the capture sequence is in progress. This API does not transfer + * ownership of the graph, which is transferred or destroyed at + * ::cuStreamEndCapture. Note that the graph handle may be invalidated + * before end of capture for certain errors. Nodes that are or become + * unreachable from the original stream at ::cuStreamEndCapture due to + * direct actions on the graph do not trigger + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED. \param dependencies_out - Optional + * location to store a pointer to an array of nodes. The next node to be + * captured in the stream will depend on this set of nodes, absent operations + * such as event wait which modify this set. The array pointer is valid until + * the next API call which operates on the stream or until the capture is + * terminated. The node handles may be copied out and are valid until they or + * the graph is destroyed. The driver-owned array may also be passed directly to + * APIs that operate on the graph (not the stream) without copying. \param + * numDependencies_out - Optional location to store the size of the array + * returned in dependencies_out. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo_v3 + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing, + * ::cuStreamUpdateCaptureDependencies + */ +CUresult CUDAAPI cuStreamGetCaptureInfo( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, + const CUgraphNode **dependencies_out, size_t *numDependencies_out); + +/** + * \brief Query a stream's capture state (12.3+) + * + * Query stream state related to stream capture. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not + * created with ::CU_STREAM_NON_BLOCKING is capturing, returns + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. + * + * Valid data (other than capture status) is returned only if both of the + * following are true: + * - the call returns CUDA_SUCCESS + * - the returned capture status is ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * If \p edgeData_out is non-NULL then \p dependencies_out must be as well. If + * \p dependencies_out is non-NULL and \p edgeData_out is NULL, but there is + * non-zero edge data for one or more of the current stream dependencies, the + * call will return + * ::CUDA_ERROR_LOSSY_QUERY. + * + * \param hStream - The stream to query + * \param captureStatus_out - Location to return the capture status of the + * stream; required \param id_out - Optional location to return an id for the + * capture sequence, which is unique over the lifetime of the process \param + * graph_out - Optional location to return the graph being captured into. 
All + * operations other than destroy and node removal are permitted on the + * graph while the capture sequence is in progress. This API does not transfer + * ownership of the graph, which is transferred or destroyed at + * ::cuStreamEndCapture. Note that the graph handle may be invalidated + * before end of capture for certain errors. Nodes that are or become + * unreachable from the original stream at ::cuStreamEndCapture due to + * direct actions on the graph do not trigger + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED. \param dependencies_out - Optional + * location to store a pointer to an array of nodes. The next node to be + * captured in the stream will depend on this set of nodes, absent operations + * such as event wait which modify this set. The array pointer is valid until + * the next API call which operates on the stream or until the capture is + * terminated. The node handles may be copied out and are valid until they or + * the graph is destroyed. The driver-owned array may also be passed directly to + * APIs that operate on the graph (not the stream) without copying. \param + * edgeData_out - Optional location to store a pointer to an array of graph edge + * data. This array parallels \c dependencies_out; the next node to be + * added has an edge to \c dependencies_out[i] with annotation \c + * edgeData_out[i] for each \c i. The array pointer is valid until the next API + * call which operates on the stream or until the capture is terminated. \param + * numDependencies_out - Optional location to store the size of the array + * returned in dependencies_out. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT, + * ::CUDA_ERROR_LOSSY_QUERY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing, + * ::cuStreamUpdateCaptureDependencies + */ +CUresult CUDAAPI cuStreamGetCaptureInfo_v3( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, + const CUgraphNode **dependencies_out, const CUgraphEdgeData **edgeData_out, + size_t *numDependencies_out); + +/** + * \brief Update the set of dependencies in a capturing stream (11.3+) + * + * Modifies the dependency set of a capturing stream. The dependency set is the + * set of nodes that the next captured node in the stream will depend on. + * + * Valid flags are ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES and + * ::CU_STREAM_SET_CAPTURE_DEPENDENCIES. These control whether the set passed to + * the API is added to the existing set or replaces it. A flags value of 0 + * defaults to ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES. + * + * Nodes that are removed from the dependency set via this API do not result in + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED if they are unreachable from the stream + * at + * ::cuStreamEndCapture. + * + * Returns ::CUDA_ERROR_ILLEGAL_STATE if the stream is not capturing. + * + * This API is new in CUDA 11.3. Developers requiring compatibility across minor + * versions to CUDA 11.0 should not use this API or provide a fallback. 
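+ *
+ * For illustration only (not part of the original header; error checking is
+ * omitted and \c hStream is a placeholder), a node added directly to the
+ * capture graph can be spliced into the dependency set as follows: \code
    CUstreamCaptureStatus status;
    CUgraph graph;
    CUgraphNode node;
    cuStreamGetCaptureInfo(hStream, &status, NULL, &graph, NULL, NULL);
    cuGraphAddEmptyNode(&node, graph, NULL, 0);
    cuStreamUpdateCaptureDependencies(hStream, &node, 1,
                                      CU_STREAM_SET_CAPTURE_DEPENDENCIES);
+ * \endcode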
+ * + * \param hStream - The stream to update + * \param dependencies - The set of dependencies to add + * \param numDependencies - The size of the dependencies array + * \param flags - See above + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_ILLEGAL_STATE + * + * \sa + * ::cuStreamBeginCapture, + * ::cuStreamGetCaptureInfo, + */ +CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, + CUgraphNode *dependencies, + size_t numDependencies, + unsigned int flags); + +/** + * \brief Update the set of dependencies in a capturing stream (12.3+) + * + * Modifies the dependency set of a capturing stream. The dependency set is the + * set of nodes that the next captured node in the stream will depend on along + * with the edge data for those dependencies. + * + * Valid flags are ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES and + * ::CU_STREAM_SET_CAPTURE_DEPENDENCIES. These control whether the set passed to + * the API is added to the existing set or replaces it. A flags value of 0 + * defaults to ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES. + * + * Nodes that are removed from the dependency set via this API do not result in + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED if they are unreachable from the stream + * at + * ::cuStreamEndCapture. + * + * Returns ::CUDA_ERROR_ILLEGAL_STATE if the stream is not capturing. + * + * \param hStream - The stream to update + * \param dependencies - The set of dependencies to add + * \param dependencyData - Optional array of data associated with each + * dependency. \param numDependencies - The size of the dependencies array + * \param flags - See above + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_ILLEGAL_STATE + * + * \sa + * ::cuStreamBeginCapture, + * ::cuStreamGetCaptureInfo, + */ +CUresult CUDAAPI cuStreamUpdateCaptureDependencies_v2( + CUstream hStream, CUgraphNode *dependencies, + const CUgraphEdgeData *dependencyData, size_t numDependencies, + unsigned int flags); + +/** + * \brief Attach memory to a stream asynchronously + * + * Enqueues an operation in \p hStream to specify stream association of + * \p length bytes of memory starting from \p dptr. This function is a + * stream-ordered operation, meaning that it is dependent on, and will + * only take effect when, previous work in stream has completed. Any + * previous association is automatically replaced. + * + * \p dptr must point to one of the following types of memories: + * - managed memory declared using the __managed__ keyword or allocated with + * ::cuMemAllocManaged. + * - a valid host-accessible region of system-allocated pageable memory. This + * type of memory may only be specified if the device associated with the + * stream reports a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. + * + * For managed allocations, \p length must be either zero or the entire + * allocation's size. Both indicate that the entire allocation's stream + * association is being changed. Currently, it is not possible to change stream + * association for a portion of a managed allocation. + * + * For pageable host allocations, \p length must be non-zero. + * + * The stream association is specified using \p flags which must be + * one of ::CUmemAttach_flags. + * If the ::CU_MEM_ATTACH_GLOBAL flag is specified, the memory can be accessed + * by any stream on any device. 
+ * If the ::CU_MEM_ATTACH_HOST flag is specified, the program makes a guarantee + * that it won't access the memory on the device from any stream on a device + * that has a zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. If the + * ::CU_MEM_ATTACH_SINGLE flag is specified and \p hStream is associated with a + * device that has a zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, the program makes a + * guarantee that it will only access the memory on the device from \p hStream. + * It is illegal to attach singly to the NULL stream, because the NULL stream is + * a virtual global stream and not a specific stream. An error will be returned + * in this case. + * + * When memory is associated with a single stream, the Unified Memory system + * will allow CPU access to this memory region so long as all operations in \p + * hStream have completed, regardless of whether other streams are active. In + * effect, this constrains exclusive ownership of the managed memory region by + * an active GPU to per-stream activity instead of whole-GPU activity. + * + * Accessing memory on the device from streams that are not associated with + * it will produce undefined results. No error checking is performed by the + * Unified Memory system to ensure that kernels launched into other streams + * do not access this region. + * + * It is a program's responsibility to order calls to ::cuStreamAttachMemAsync + * via events, synchronization or other means to ensure legal access to memory + * at all times. Data visibility and coherency will be changed appropriately + * for all kernels which follow a stream-association change. + * + * If \p hStream is destroyed while data is associated with it, the association + * is removed and the association reverts to the default visibility of the + * allocation as specified at ::cuMemAllocManaged. For __managed__ variables, + * the default association is always ::CU_MEM_ATTACH_GLOBAL. Note that + * destroying a stream is an asynchronous operation, and as a result, the change + * to default association won't happen until all work in the stream has + * completed. + * + * \param hStream - Stream in which to enqueue the attach operation + * \param dptr - Pointer to memory (must be a pointer to managed memory or + * to a valid host-accessible region of system-allocated + * pageable memory) + * \param length - Length of memory + * \param flags - Must be one of ::CUmemAttach_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cudaStreamAttachMemAsync + */ +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags); + +/** + * \brief Determine status of a compute stream + * + * Returns ::CUDA_SUCCESS if all operations in the stream specified by + * \p hStream have completed, or ::CUDA_ERROR_NOT_READY if not. + * + * For the purposes of Unified Memory, a return value of ::CUDA_SUCCESS + * is equivalent to having called ::cuStreamSynchronize(). 
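+ *
+ * As an illustrative sketch (not part of the original header; \c hStream is
+ * a placeholder), the non-blocking return code can be used to overlap host
+ * work with outstanding device work: \code
    while (cuStreamQuery(hStream) == CUDA_ERROR_NOT_READY) {
        // do useful host-side work while the stream is still busy
    }
+ * \endcode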
+ * + * \param hStream - Stream to query status of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_READY + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamQuery + */ +CUresult CUDAAPI cuStreamQuery(CUstream hStream); + +/** + * \brief Wait until a stream's tasks are completed + * + * Waits until the device has completed all operations in the stream specified + * by \p hStream. If the context was created with the + * ::CU_CTX_SCHED_BLOCKING_SYNC flag, the CPU thread will block until the + * stream is finished with all of its tasks. + * + * \param hStream - Stream to wait for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamDestroy, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamAddCallback, + * ::cudaStreamSynchronize + */ +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); + +/** + * \brief Destroys a stream + * + * Destroys the stream specified by \p hStream. + * + * In case the device is still doing work in the stream \p hStream + * when ::cuStreamDestroy() is called, the function will return immediately + * and the resources associated with \p hStream will be released automatically + * once the device has completed all work in \p hStream. + * + * \param hStream - Stream to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamDestroy + */ +CUresult CUDAAPI cuStreamDestroy(CUstream hStream); + +/** + * \brief Copies attributes from source stream to destination stream. + * + * Copies attributes from source stream \p src to destination stream \p dst. + * Both streams must have the same context. + * + * \param[out] dst Destination stream + * \param[in] src Source stream + * For list of attributes see ::CUstreamAttrID + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamCopyAttributes(CUstream dst, CUstream src); + +/** + * \brief Queries stream attribute. + * + * Queries attribute \p attr from \p hStream and stores it in corresponding + * member of \p value_out. + * + * \param[in] hStream + * \param[in] attr + * \param[out] value_out + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, + CUstreamAttrValue *value_out); + +/** + * \brief Sets stream attribute. + * + * Sets attribute \p attr on \p hStream from corresponding attribute of + * \p value. The updated attribute will be applied to subsequent work + * submitted to the stream. It will not affect previously submitted work. 
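+ *
+ * A hedged sketch (not part of the original header; the
+ * ::CUaccessPolicyWindow field and enum names are assumed, and \c devPtr and
+ * \c numBytes are placeholders): \code
    CUstreamAttrValue value;
    memset(&value, 0, sizeof(value));
    value.accessPolicyWindow.base_ptr  = (void *)devPtr;
    value.accessPolicyWindow.num_bytes = numBytes;
    value.accessPolicyWindow.hitRatio  = 0.6f;
    value.accessPolicyWindow.hitProp   = CU_ACCESS_PROPERTY_PERSISTING;
    value.accessPolicyWindow.missProp  = CU_ACCESS_PROPERTY_STREAMING;
    cuStreamSetAttribute(hStream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
                         &value);
+ * \endcode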
+ * + * \param[out] hStream + * \param[in] attr + * \param[in] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, + const CUstreamAttrValue *value); + +/** @} */ /* END CUDA_STREAM */ + +/** + * \defgroup CUDA_EVENT Event Management + * + * ___MANBRIEF___ event management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the event management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Creates an event + * + * Creates an event *phEvent for the current context with the flags specified + * via \p Flags. Valid flags include: + * - ::CU_EVENT_DEFAULT: Default event creation flag. + * - ::CU_EVENT_BLOCKING_SYNC: Specifies that the created event should use + * blocking synchronization. A CPU thread that uses ::cuEventSynchronize() to + * wait on an event created with this flag will block until the event has + * actually been recorded. + * - ::CU_EVENT_DISABLE_TIMING: Specifies that the created event does not need + * to record timing data. Events created with this flag specified and + * the ::CU_EVENT_BLOCKING_SYNC flag not specified will provide the best + * performance when used with ::cuStreamWaitEvent() and ::cuEventQuery(). + * - ::CU_EVENT_INTERPROCESS: Specifies that the created event may be used as an + * interprocess event by ::cuIpcGetEventHandle(). ::CU_EVENT_INTERPROCESS must + * be specified along with ::CU_EVENT_DISABLE_TIMING. + * + * \param phEvent - Returns newly created event + * \param Flags - Event creation flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventCreate, + * ::cudaEventCreateWithFlags + */ +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags); + +/** + * \brief Records an event + * + * Captures in \p hEvent the contents of \p hStream at the time of this call. + * \p hEvent and \p hStream must be from the same context. + * Calls such as ::cuEventQuery() or ::cuStreamWaitEvent() will then + * examine or wait for completion of the work that was captured. Uses of + * \p hStream after this call do not modify \p hEvent. See note on default + * stream behavior for what is captured in the default case. + * + * ::cuEventRecord() can be called multiple times on the same event and + * will overwrite the previously captured state. Other APIs such as + * ::cuStreamWaitEvent() use the most recently captured state at the time + * of the API call, and are not affected by later calls to + * ::cuEventRecord(). Before the first call to ::cuEventRecord(), an + * event represents an empty set of work, so for example ::cuEventQuery() + * would return ::CUDA_SUCCESS. 
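+ *
+ * For illustration (not part of the original header; \c streamA and
+ * \c streamB are placeholders and error checking is omitted), recording an
+ * event is commonly paired with ::cuStreamWaitEvent to order work across two
+ * streams: \code
    CUevent event;
    cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
    cuEventRecord(event, streamA);        // capture streamA's work so far
    cuStreamWaitEvent(streamB, event, 0); // streamB waits for that work
    cuEventDestroy(event);
+ * \endcode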
+ * + * \param hEvent - Event to record + * \param hStream - Stream to record event for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \note_null_stream + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuStreamWaitEvent, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventRecord, + * ::cuEventRecordWithFlags + */ +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); + +/** + * \brief Records an event + * + * Captures in \p hEvent the contents of \p hStream at the time of this call. + * \p hEvent and \p hStream must be from the same context. + * Calls such as ::cuEventQuery() or ::cuStreamWaitEvent() will then + * examine or wait for completion of the work that was captured. Uses of + * \p hStream after this call do not modify \p hEvent. See note on default + * stream behavior for what is captured in the default case. + * + * ::cuEventRecordWithFlags() can be called multiple times on the same event and + * will overwrite the previously captured state. Other APIs such as + * ::cuStreamWaitEvent() use the most recently captured state at the time + * of the API call, and are not affected by later calls to + * ::cuEventRecordWithFlags(). Before the first call to + * ::cuEventRecordWithFlags(), an event represents an empty set of work, so for + * example ::cuEventQuery() would return ::CUDA_SUCCESS. + * + * flags include: + * - ::CU_EVENT_RECORD_DEFAULT: Default event creation flag. + * - ::CU_EVENT_RECORD_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. + * + * \param hEvent - Event to record + * \param hStream - Stream to record event for + * \param flags - See ::CUevent_capture_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \note_null_stream + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuStreamWaitEvent, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cuEventRecord, + * ::cudaEventRecord + */ +CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, + unsigned int flags); + +/** + * \brief Queries an event's status + * + * Queries the status of all work currently captured by \p hEvent. See + * ::cuEventRecord() for details on what is captured by an event. + * + * Returns ::CUDA_SUCCESS if all captured work has been completed, or + * ::CUDA_ERROR_NOT_READY if any captured work is incomplete. + * + * For the purposes of Unified Memory, a return value of ::CUDA_SUCCESS + * is equivalent to having called ::cuEventSynchronize(). + * + * \param hEvent - Event to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_READY + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventQuery + */ +CUresult CUDAAPI cuEventQuery(CUevent hEvent); + +/** + * \brief Waits for an event to complete + * + * Waits until the completion of all work currently captured in \p hEvent. 
+ * See ::cuEventRecord() for details on what is captured by an event. + * + * Waiting for an event that was created with the ::CU_EVENT_BLOCKING_SYNC + * flag will cause the calling CPU thread to block until the event has + * been completed by the device. If the ::CU_EVENT_BLOCKING_SYNC flag has + * not been set, then the CPU thread will busy-wait until the event has + * been completed by the device. + * + * \param hEvent - Event to wait for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventSynchronize + */ +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent); + +/** + * \brief Destroys an event + * + * Destroys the event specified by \p hEvent. + * + * An event may be destroyed before it is complete (i.e., while + * ::cuEventQuery() would return ::CUDA_ERROR_NOT_READY). In this case, the + * call does not block on completion of the event, and any associated + * resources will automatically be released asynchronously at completion. + * + * \param hEvent - Event to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventElapsedTime, + * ::cudaEventDestroy + */ +CUresult CUDAAPI cuEventDestroy(CUevent hEvent); + +/** + * \brief Computes the elapsed time between two events + * + * Computes the elapsed time between two events (in milliseconds with a + * resolution of around 0.5 microseconds). + * + * If either event was last recorded in a non-NULL stream, the resulting time + * may be greater than expected (even if both used the same stream handle). This + * happens because the ::cuEventRecord() operation takes place asynchronously + * and there is no guarantee that the measured latency is actually just between + * the two events. Any number of other different stream operations could execute + * in between the two measured events, thus altering the timing in a significant + * way. + * + * If ::cuEventRecord() has not been called on either event then + * ::CUDA_ERROR_INVALID_HANDLE is returned. If ::cuEventRecord() has been called + * on both events but one or both of them has not yet been completed (that is, + * ::cuEventQuery() would return ::CUDA_ERROR_NOT_READY on at least one of the + * events), ::CUDA_ERROR_NOT_READY is returned. If either event was created with + * the ::CU_EVENT_DISABLE_TIMING flag, then this function will return + * ::CUDA_ERROR_INVALID_HANDLE. 
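+ *
+ * A minimal timing sketch (illustrative only, not part of the original
+ * header; \c hStream is a placeholder and error checking is omitted): \code
    CUevent start, stop;
    float milliseconds;
    cuEventCreate(&start, CU_EVENT_DEFAULT);
    cuEventCreate(&stop, CU_EVENT_DEFAULT);
    cuEventRecord(start, hStream);
    // ... launch the work to be timed into hStream ...
    cuEventRecord(stop, hStream);
    cuEventSynchronize(stop);
    cuEventElapsedTime(&milliseconds, start, stop);
+ * \endcode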
+ * + * \param pMilliseconds - Time between \p hStart and \p hEnd in ms + * \param hStart - Starting event + * \param hEnd - Ending event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_READY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cudaEventElapsedTime + */ +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd); + +/** @} */ /* END CUDA_EVENT */ + +/** + * \defgroup CUDA_EXTRES_INTEROP External Resource Interoperability + * + * ___MANBRIEF___ External resource interoperability functions of the low-level + * CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the external resource interoperability functions of + * the low-level CUDA driver application programming interface. + * + * @{ + */ + +/** +* \brief Imports an external memory object +* +* Imports an externally allocated memory object and returns +* a handle to that in \p extMem_out. +* +* The properties of the handle being imported must be described in +* \p memHandleDesc. The ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC structure +* is defined as follows: +* +* \code + typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { + CUexternalMemoryHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void *nvSciBufObject; + } handle; + unsigned long long size; + unsigned int flags; + } CUDA_EXTERNAL_MEMORY_HANDLE_DESC; +* \endcode +* +* where ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type specifies the type +* of handle being imported. ::CUexternalMemoryHandleType is +* defined as: +* +* \code + typedef enum CUexternalMemoryHandleType_enum { + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 + } CUexternalMemoryHandleType; +* \endcode +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD, then +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::fd must be a valid +* file descriptor referencing a memory object. Ownership of +* the file descriptor is transferred to the CUDA driver when the +* handle is imported successfully. Performing any operations on the +* file descriptor after it is imported results in undefined behavior. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32, then exactly one +* of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be +* NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle +* is not NULL, then it must represent a valid shared NT handle that +* references a memory object. Ownership of this handle is +* not transferred to CUDA after the import operation, so the +* application must release the handle using the appropriate system +* call. 
If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* is not NULL, then it must point to a NULL-terminated array of +* UTF-16 characters that refers to a memory object. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT, then +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must +* be non-NULL and +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* must be NULL. The handle specified must be a globally shared KMT +* handle. This handle does not hold a reference to the underlying +* object, and thus will be invalid when all references to the +* memory object are destroyed. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, then exactly one +* of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be +* NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle +* is not NULL, then it must represent a valid shared NT handle that +* is returned by ID3D12Device::CreateSharedHandle when referring to a +* ID3D12Heap object. This handle holds a reference to the underlying +* object. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* is not NULL, then it must point to a NULL-terminated array of +* UTF-16 characters that refers to a ID3D12Heap object. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, then exactly one +* of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be +* NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle +* is not NULL, then it must represent a valid shared NT handle that +* is returned by ID3D12Device::CreateSharedHandle when referring to a +* ID3D12Resource object. This handle holds a reference to the +* underlying object. If +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* is not NULL, then it must point to a NULL-terminated array of +* UTF-16 characters that refers to a ID3D12Resource object. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE, then +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must +* represent a valid shared NT handle that is returned by +* IDXGIResource1::CreateSharedHandle when referring to a +* ID3D11Resource object. If +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* is not NULL, then it must point to a NULL-terminated array of +* UTF-16 characters that refers to a ID3D11Resource object. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT, then +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must +* represent a valid shared KMT handle that is returned by +* IDXGIResource::GetSharedHandle when referring to a +* ID3D11Resource object and +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name +* must be NULL. +* +* If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::nvSciBufObject must be non-NULL +* and reference a valid NvSciBuf object. +* If the NvSciBuf object imported into CUDA is also mapped by other drivers, +then the +* application must use ::cuWaitExternalSemaphoresAsync or +::cuSignalExternalSemaphoresAsync +* as appropriate barriers to maintain coherence between CUDA and the other +drivers. 
+* See ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC and +::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC +* for memory synchronization. +* +* +* The size of the memory object must be specified in +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::size. +* +* Specifying the flag ::CUDA_EXTERNAL_MEMORY_DEDICATED in +* ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::flags indicates that the +* resource is a dedicated resource. The definition of what a +* dedicated resource is outside the scope of this extension. +* This flag must be set if ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type +* is one of the following: +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE +* ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT +* +* \param extMem_out - Returned handle to an external memory object +* \param memHandleDesc - Memory import handle descriptor +* +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_HANDLE, +* ::CUDA_ERROR_OPERATING_SYSTEM +* \notefnerr +* +* \note If the Vulkan memory imported into CUDA is mapped on the CPU then the +* application must use vkInvalidateMappedMemoryRanges/vkFlushMappedMemoryRanges +* as well as appropriate Vulkan pipeline barriers to maintain coherence between +* CPU and GPU. For more information on these APIs, please refer to +"Synchronization +* and Cache Control" chapter from Vulkan specification. +* +* \sa ::cuDestroyExternalMemory, +* ::cuExternalMemoryGetMappedBuffer, +* ::cuExternalMemoryGetMappedMipmappedArray +*/ +CUresult CUDAAPI +cuImportExternalMemory(CUexternalMemory *extMem_out, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *memHandleDesc); + +/** + * \brief Maps a buffer onto an imported memory object + * + * Maps a buffer onto an imported memory object and returns a device + * pointer in \p devPtr. + * + * The properties of the buffer being mapped must be described in + * \p bufferDesc. The ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC structure is + * defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_MEMORY_BUFFER_DESC_st { + unsigned long long offset; + unsigned long long size; + unsigned int flags; + } CUDA_EXTERNAL_MEMORY_BUFFER_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::offset is the offset in + * the memory object where the buffer's base address is. + * ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::size is the size of the buffer. + * ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::flags must be zero. + * + * The offset and size have to be suitably aligned to match the + * requirements of the external API. Mapping two buffers whose ranges + * overlap may or may not result in the same virtual address being + * returned for the overlapped portion. In such cases, the application + * must ensure that all accesses to that region from the GPU are + * volatile. Otherwise writes made via one address are not guaranteed + * to be visible via the other address, even if they're issued by the + * same thread. It is recommended that applications map the combined + * range instead of mapping separate buffers and then apply the + * appropriate offsets to the returned pointer to derive the + * individual buffers. + * + * The returned pointer \p devPtr must be freed using ::cuMemFree. 
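+ *
+ * An illustrative mapping sketch (not part of the original header;
+ * \c importedSize is a placeholder for the size supplied at import time and
+ * error checking is omitted): \code
    CUDA_EXTERNAL_MEMORY_BUFFER_DESC bufferDesc;
    memset(&bufferDesc, 0, sizeof(bufferDesc));
    bufferDesc.offset = 0;
    bufferDesc.size   = importedSize;
    CUdeviceptr devPtr;
    cuExternalMemoryGetMappedBuffer(&devPtr, extMem, &bufferDesc);
    // ... use devPtr on the device ...
    cuMemFree(devPtr);
+ * \endcode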
+ * + * \param devPtr - Returned device pointer to buffer + * \param extMem - Handle to external memory object + * \param bufferDesc - Buffer descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuDestroyExternalMemory, + * ::cuExternalMemoryGetMappedMipmappedArray + */ +CUresult CUDAAPI cuExternalMemoryGetMappedBuffer( + CUdeviceptr *devPtr, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *bufferDesc); + +/** + * \brief Maps a CUDA mipmapped array onto an external memory object + * + * Maps a CUDA mipmapped array onto an external object and returns a + * handle to it in \p mipmap. + * + * The properties of the CUDA mipmapped array being mapped must be + * described in \p mipmapDesc. The structure + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC is defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_st { + unsigned long long offset; + CUDA_ARRAY3D_DESCRIPTOR arrayDesc; + unsigned int numLevels; + } CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::offset is the + * offset in the memory object where the base level of the mipmap + * chain is. + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::arrayDesc describes + * the format, dimensions and type of the base level of the mipmap + * chain. For further details on these parameters, please refer to the + * documentation for ::cuMipmappedArrayCreate. Note that if the mipmapped + * array is bound as a color target in the graphics API, then the flag + * ::CUDA_ARRAY3D_COLOR_ATTACHMENT must be specified in + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::arrayDesc::Flags. + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels specifies + * the total number of levels in the mipmap chain. + * + * If \p extMem was imported from a handle of type + ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels must be equal to 1. + * + * The returned CUDA mipmapped array must be freed using + ::cuMipmappedArrayDestroy. + * + * \param mipmap - Returned CUDA mipmapped array + * \param extMem - Handle to external memory object + * \param mipmapDesc - CUDA array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuDestroyExternalMemory, + * ::cuExternalMemoryGetMappedBuffer + */ +CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray( + CUmipmappedArray *mipmap, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc); + +/** + * \brief Destroys an external memory object. + * + * Destroys the specified external memory object. Any existing buffers + * and CUDA mipmapped arrays mapped onto this object must no longer be + * used and must be explicitly freed using ::cuMemFree and + * ::cuMipmappedArrayDestroy respectively. 
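+ *
+ * A teardown-order sketch (illustrative only, not part of the original
+ * header; \c devPtr, \c mipmap and \c extMem are placeholders): \code
    cuMemFree(devPtr);                  // free mapped buffers first
    cuMipmappedArrayDestroy(mipmap);    // and mapped mipmapped arrays
    cuDestroyExternalMemory(extMem);    // then destroy the imported object
+ * \endcode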
+ * + * \param extMem - External memory object to be destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuExternalMemoryGetMappedBuffer, + * ::cuExternalMemoryGetMappedMipmappedArray + */ +CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); + +/** + * \brief Imports an external semaphore + * + * Imports an externally allocated synchronization object and returns + * a handle to that in \p extSem_out. + * + * The properties of the handle being imported must be described in + * \p semHandleDesc. The ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC is + * defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { + CUexternalSemaphoreHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void* NvSciSyncObj; + } handle; + unsigned int flags; + } CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type specifies the type of + * handle being imported. ::CUexternalSemaphoreHandleType is defined + * as: + * + * \code + typedef enum CUexternalSemaphoreHandleType_enum { + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 + } CUexternalSemaphoreHandleType; + * \endcode + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a synchronization object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a synchronization object. Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle must + * be non-NULL and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * must be NULL. The handle specified must be a globally shared KMT + * handle. 
This handle does not hold a reference to the underlying + * object, and thus will be invalid when all references to the + * synchronization object are destroyed. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * is returned by ID3D12Device::CreateSharedHandle when referring to a + * ID3D12Fence object. This handle holds a reference to the underlying + * object. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid ID3D12Fence object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that is returned by + * ID3D11Fence::CreateSharedHandle. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid ID3D11Fence object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::nvSciSyncObj + * represents a valid NvSciSyncObj. + * + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that + * is returned by IDXGIResource1::CreateSharedHandle when referring to + * a IDXGIKeyedMutex object. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid IDXGIKeyedMutex object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared KMT handle that + * is returned by IDXGIResource::GetSharedHandle when referring to + * a IDXGIKeyedMutex object and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must be NULL. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a synchronization object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32, then exactly + one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a synchronization object. 
Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object. + * + * \param extSem_out - Returned handle to an external semaphore + * \param semHandleDesc - Semaphore import handle descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OPERATING_SYSTEM + * \notefnerr + * + * \sa ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuImportExternalSemaphore( + CUexternalSemaphore *extSem_out, + const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *semHandleDesc); + +/** + * \brief Signals a set of external semaphore objects + * + * Enqueues a signal operation on a set of externally allocated + * semaphore object in the specified stream. The operations will be + * executed when all prior operations in the stream complete. + * + * The exact semantics of signaling a semaphore depends on the type of + * the object. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * then signaling the semaphore will set it to the signaled state. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then the semaphore will be set to the value specified in + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::fence::value. + * + * If the semaphore object is of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC this API sets + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence to a value + * that can be used by subsequent waiters of the same NvSciSync object to order + * operations with those currently submitted in \p stream. Such an update will + * overwrite previous contents of + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence. By + * default, signaling such an external semaphore object causes appropriate + * memory synchronization operations to be performed over all external memory + * objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that any subsequent + * accesses made by other importers of the same set of NvSciBuf memory object(s) + * are coherent. These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC, which can be used as + * a performance optimization when data coherency is not required. But + * specifying this flag in scenarios where data coherency is required results in + * undefined behavior. Also, for semaphore object of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, if the NvSciSyncAttrList used + * to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_SIGNAL, this API will + * return CUDA_ERROR_NOT_SUPPORTED. NvSciSyncFence associated with semaphore + * object of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC can be deterministic. 
For this + * the NvSciSyncAttrList used to create the semaphore object must have value of + * NvSciSyncAttrKey_RequireDeterministicFences key set to true. Deterministic + * fences allow users to enqueue a wait over the semaphore object even before + * corresponding signal is enqueued. For such a semaphore object, CUDA + * guarantees that each signal operation will increment the fence value by '1'. + * Users are expected to track count of signals enqueued on the semaphore object + * and insert waits accordingly. When such a semaphore object is signaled from + * multiple streams, due to concurrent stream execution, it is possible that the + * order in which the semaphore gets signaled is indeterministic. This could + * lead to waiters of the semaphore getting unblocked incorrectly. Users are + * expected to handle such situations, either by not using the same semaphore + * object with deterministic fence support enabled in different streams or by + * adding explicit dependency amongst such streams so that the semaphore is + * signaled in order. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be released with the key specified in + * ::CUDA_EXTERNAL_SEMAPHORE_PARAMS::params::keyedmutex::key. + * + * \param extSemArray - Set of external semaphores to be signaled + * \param paramsArray - Array of semaphore parameters + * \param numExtSems - Number of semaphores to signal + * \param stream - Stream to enqueue the signal operations in + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuSignalExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream); + +/** + * \brief Waits on a set of external semaphore objects + * + * Enqueues a wait operation on a set of externally allocated + * semaphore object in the specified stream. The operations will be + * executed when all prior operations in the stream complete. + * + * The exact semantics of waiting on a semaphore depends on the type + * of the object. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * then waiting on the semaphore will wait until the semaphore reaches + * the signaled state. The semaphore will then be reset to the + * unsignaled state. Therefore for every signal operation, there can + * only be one wait operation. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then waiting on the semaphore will wait until the value of the + * semaphore is greater than or equal to + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::fence::value. 
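+ *
+ * As an illustrative, non-normative sketch (the handle \p extSem, the stream
+ * \p stream and the target \p fenceValue are placeholders), a wait until a
+ * fence-backed semaphore reaches a given value could be enqueued as follows:
+ *
+ * \code
+   CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS waitParams;
+   memset(&waitParams, 0, sizeof(waitParams));
+   waitParams.params.fence.value = fenceValue;   // value to wait for
+
+   // Work enqueued in `stream` after this call is ordered behind the wait.
+   cuWaitExternalSemaphoresAsync(&extSem, &waitParams, 1, stream);
+ * \endcode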
+ * + * If the semaphore object is of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC then, waiting on the semaphore + * will wait until the + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence is signaled + * by the signaler of the NvSciSyncObj that was associated with this semaphore + * object. By default, waiting on such an external semaphore object causes + * appropriate memory synchronization operations to be performed over all + * external memory objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that any subsequent + * accesses made by other importers of the same set of NvSciBuf memory object(s) + * are coherent. These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC, which can be used as a + * performance optimization when data coherency is not required. But specifying + * this flag in scenarios where data coherency is required results in undefined + * behavior. Also, for semaphore object of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, if the NvSciSyncAttrList used + * to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_WAIT, this API will + * return CUDA_ERROR_NOT_SUPPORTED. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be acquired when it is released with the key + * specified in ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::key + * or until the timeout specified by + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::timeoutMs + * has lapsed. The timeout interval can either be a finite value + * specified in milliseconds or an infinite value. In case an infinite + * value is specified the timeout never elapses. The windows INFINITE + * macro must be used to specify infinite timeout. + * + * \param extSemArray - External semaphores to be waited on + * \param paramsArray - Array of semaphore parameters + * \param numExtSems - Number of semaphores to wait on + * \param stream - Stream to enqueue the wait operations in + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_TIMEOUT + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync + */ +CUresult CUDAAPI cuWaitExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream); + +/** + * \brief Destroys an external semaphore + * + * Destroys an external semaphore object and releases any references + * to the underlying resource. Any outstanding signals or waits must + * have completed before the semaphore is destroyed. 
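+ *
+ * For context, a minimal, illustrative lifecycle sketch in which destruction
+ * is the final step (assuming \p fd is a file descriptor exported by another
+ * API and \p stream is an existing stream; error handling is omitted):
+ *
+ * \code
+   CUexternalSemaphore extSem;
+   CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC desc;
+   memset(&desc, 0, sizeof(desc));
+   desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD;
+   desc.handle.fd = fd;                       // ownership passes to the driver
+   cuImportExternalSemaphore(&extSem, &desc);
+
+   CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS sig;
+   memset(&sig, 0, sizeof(sig));
+   cuSignalExternalSemaphoresAsync(&extSem, &sig, 1, stream);
+
+   // ... once all outstanding signals/waits have completed ...
+   cuDestroyExternalSemaphore(extSem);
+ * \endcode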
+ * + * \param extSem - External semaphore to be destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); + +/** @} */ /* END CUDA_EXTRES_INTEROP */ + +/** + * \defgroup CUDA_MEMOP Stream Memory Operations + * + * ___MANBRIEF___ Stream memory operations of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream memory operations of the low-level CUDA + * driver application programming interface. + * + * Support for the ::CU_STREAM_WAIT_VALUE_NOR flag can be queried with + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V2. + * + * Support for the ::cuStreamWriteValue64() and ::cuStreamWaitValue64() + * functions, as well as for the ::CU_STREAM_MEM_OP_WAIT_VALUE_64 and + * ::CU_STREAM_MEM_OP_WRITE_VALUE_64 flags, can be queried with + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * Support for both ::CU_STREAM_WAIT_VALUE_FLUSH and + * ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES requires dedicated platform + * hardware features and can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES. + * + * Note that all memory pointers passed as parameters to these operations + * are device pointers. Where necessary a device pointer should be + * obtained, for example with ::cuMemHostGetDevicePointer(). + * + * None of the operations accepts pointers to managed memory buffers + * (::cuMemAllocManaged). + * + * \note + * Warning: + * Improper use of these APIs may deadlock the application. Synchronization + * ordering established through these APIs is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by these APIs should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. + * + * @{ + */ + +/** + * \brief Wait on a memory location + * + * Enqueues a synchronization of the stream on the given memory location. Work + * ordered after the operation will block until the given condition on the + * memory is satisfied. By default, the condition is to wait for + * (int32_t)(*addr - value) >= 0, a cyclic greater-or-equal. + * Other condition types can be specified via \p flags. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot + * be used with managed memory (::cuMemAllocManaged). + * + * Support for CU_STREAM_WAIT_VALUE_NOR can be queried with + * ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V2. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. + * + * \param stream The stream to synchronize on the memory location. + * \param addr The memory location to wait on. + * \param value The value to compare with the memory location. + * \param flags See ::CUstreamWaitValue_flags. 
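+ *
+ * An illustrative sketch (placeholder names; assumes the flag lives in a
+ * host allocation registered for device mapping) of gating stream work on a
+ * 32-bit value:
+ *
+ * \code
+   static unsigned int hostFlag = 0;
+   CUdeviceptr devFlag;
+   cuMemHostRegister(&hostFlag, sizeof(hostFlag), CU_MEMHOSTREGISTER_DEVICEMAP);
+   cuMemHostGetDevicePointer(&devFlag, &hostFlag, 0);
+
+   // Work enqueued in `stream` after this call is held until hostFlag >= 1.
+   cuStreamWaitValue32(stream, devFlag, 1, CU_STREAM_WAIT_VALUE_GEQ);
+   // ... enqueue dependent work in `stream` ...
+
+   hostFlag = 1;   // released later, e.g. by another thread or another stream
+ * \endcode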
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue64, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); + +/** + * \brief Wait on a memory location + * + * Enqueues a synchronization of the stream on the given memory location. Work + * ordered after the operation will block until the given condition on the + * memory is satisfied. By default, the condition is to wait for + * (int64_t)(*addr - value) >= 0, a cyclic greater-or-equal. + * Other condition types can be specified via \p flags. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. + * + * \param stream The stream to synchronize on the memory location. + * \param addr The memory location to wait on. + * \param value The value to compare with the memory location. + * \param flags See ::CUstreamWaitValue_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue32, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); + +/** + * \brief Write a value to memory + * + * Write a value to memory. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot + * be used with managed memory (::cuMemAllocManaged). + * + * \param stream The stream to do the write in. + * \param addr The device address to write to. + * \param value The value to write. + * \param flags See ::CUstreamWriteValue_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWriteValue64, + * ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuEventRecord + */ +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); + +/** + * \brief Write a value to memory + * + * Write a value to memory. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * \param stream The stream to do the write in. + * \param addr The device address to write to. + * \param value The value to write. + * \param flags See ::CUstreamWriteValue_flags. 
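+ *
+ * A short illustrative sketch (placeholder names; \p devMarker is assumed to
+ * be a device pointer to a 64-bit location, e.g. obtained via
+ * ::cuMemHostGetDevicePointer() for a registered host allocation): publishing
+ * a completion marker once all prior work in \p stream has finished:
+ *
+ * \code
+   cuStreamWriteValue64(stream, devMarker, 42ULL,
+                        CU_STREAM_WRITE_VALUE_DEFAULT);
+ * \endcode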
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWriteValue32, + * ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuEventRecord + */ +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); + +/** + * \brief Batch operations to synchronize the stream via memory operations + * + * This is a batch version of ::cuStreamWaitValue32() and + * ::cuStreamWriteValue32(). Batching operations may avoid some performance + * overhead in both the API call and the device execution versus adding them to + * the stream in separate API calls. The operations are enqueued in the order + * they appear in the array. + * + * See ::CUstreamBatchMemOpType for the full set of supported operations, and + * ::cuStreamWaitValue32(), ::cuStreamWaitValue64(), ::cuStreamWriteValue32(), + * and ::cuStreamWriteValue64() for details of specific operations. + * + * See related APIs for details on querying support for specific operations. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. For more + * information, see the Stream Memory Operations section in the programming + * guide(https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html). + * + * \param stream The stream to enqueue the operations in. + * \param count The number of operations in the array. Must be less than 256. + * \param paramArray The types and parameters of the individual operations. + * \param flags Reserved for future expansion; must be 0. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuMemHostRegister + */ +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags); + +/** @} */ /* END CUDA_MEMOP */ + +/** + * \defgroup CUDA_EXEC Execution Control + * + * ___MANBRIEF___ execution control functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the execution control functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns information about a function + * + * Returns in \p *pi the integer value of the attribute \p attrib on the kernel + * given by \p hfunc. The supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK: The maximum number of threads + * per block, beyond which a launch of the function would fail. This number + * depends on both the function and the device on which the function is + * currently loaded. + * - ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES: The size in bytes of + * statically-allocated shared memory per block required by this function. + * This does not include dynamically-allocated shared memory requested by + * the user at runtime. + * - ::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES: The size in bytes of user-allocated + * constant memory required by this function. 
+ * - ::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES: The size in bytes of local memory + * used by each thread of this function. + * - ::CU_FUNC_ATTRIBUTE_NUM_REGS: The number of registers used by each thread + * of this function. + * - ::CU_FUNC_ATTRIBUTE_PTX_VERSION: The PTX virtual architecture version for + * which the function was compiled. This value is the major PTX version * 10 + * + the minor PTX version, so a PTX version 1.3 function would return the + * value 13. Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. + * - ::CU_FUNC_ATTRIBUTE_BINARY_VERSION: The binary architecture version for + * which the function was compiled. This value is the major binary + * version * 10 + the minor binary version, so a binary version 1.3 function + * would return the value 13. Note that this will return a value of 10 for + * legacy cubins that do not have a properly-encoded binary architecture + * version. + * - ::CU_FUNC_CACHE_MODE_CA: The attribute to indicate whether the function has + * been compiled with user specified option "-Xptxas --dlcm=ca" set . + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: The maximum size in + * bytes of dynamically-allocated shared memory. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared + * memory-L1 cache split ratio in percent of total shared memory. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET: If this attribute is set, the + * kernel must launch with a valid cluster size specified. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED: Indicates whether + * the function can be launched with non-portable cluster size. 1 is allowed, + * 0 is disallowed. A non-portable cluster size may only function on the + * specific SKUs the program is tested on. The launch might fail if the + * program is run on a different hardware platform. CUDA API provides + * cudaOccupancyMaxActiveClusters to assist with checking whether the desired + * size can be launched on the current device. A portable cluster size is + * guaranteed to be functional on all compute capabilities higher than the + * target compute capability. The portable cluster size for sm_90 is 8 blocks + * per cluster. This value may increase for future compute capabilities. The + * specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy. + * + * With a few exceptions, function attributes may also be queried on unloaded + * function handles returned from ::cuModuleEnumerateFunctions. + * ::CUDA_ERROR_FUNCTION_NOT_LOADED is returned if the attribute requires a + * fully loaded function but the function is not loaded. The loading state of a + * function may be queried using ::cuFuncIsloaded. 
::cuFuncLoad may be called to + * explicitly load a function before querying the following attributes that + * require the function to be loaded: + * - ::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK + * - ::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES + * + * \param pi - Returned attribute value + * \param attrib - Attribute requested + * \param hfunc - Function to query attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_FUNCTION_NOT_LOADED + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuLaunchKernel, + * ::cudaFuncGetAttributes, + * ::cudaFuncSetAttribute, + * ::cuFuncIsLoaded, + * ::cuFuncLoad, + * ::cuKernelGetAttribute + */ +CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc); + +/** + * \brief Sets information about a function + * + * This call sets the value of a specified attribute \p attrib on the kernel + * given by \p hfunc to an integer value specified by \p val This function + * returns CUDA_SUCCESS if the new value of the attribute could be successfully + * set. If the set fails, this call will return an error. Not all attributes can + * have values set. Attempting to set a value on a read-only attribute will + * result in an error (CUDA_ERROR_INVALID_VALUE) + * + * Supported attributes for the cuFuncSetAttribute call are: + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: This maximum size in + * bytes of dynamically-allocated shared memory. The value should contain the + * requested maximum size of dynamically-allocated shared memory. The sum of + * this value and the function attribute ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES + * cannot exceed the device attribute + * ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN. The maximal size of + * requestable dynamic shared memory may differ by GPU architecture. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: On devices where the + * L1 cache and shared memory use the same hardware resources, this sets the + * shared memory carveout preference, in percent of the total shared memory. See + * ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR This is only a + * hint, and the driver can choose a different ratio if required to execute the + * function. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. 
+ * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy. + * + * \param hfunc - Function to query attribute of + * \param attrib - Attribute requested + * \param value - The value to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuLaunchKernel, + * ::cudaFuncGetAttributes, + * ::cudaFuncSetAttribute, + * ::cuKernelSetAttribute + */ +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value); + +/** + * \brief Sets the preferred cache configuration for a device function + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the device function \p hfunc. This is only a preference. The driver will use + * the requested configuration if possible, but it is free to choose a different + * configuration if required to execute \p hfunc. Any context-wide preference + * set via ::cuCtxSetCacheConfig() will be overridden by this per-function + * setting unless the per-function setting is ::CU_FUNC_CACHE_PREFER_NONE. In + * that case, the current context-wide setting will be used. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. + * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. + * + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 + * (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 + * cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param hfunc - Kernel to configure cache for + * \param config - Requested cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchKernel, + * ::cudaFuncSetCacheConfig, + * ::cuKernelSetCacheConfig + */ +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); + +/** + * \brief Returns a module handle + * + * Returns in \p *hmod the handle of the module that function \p hfunc + * is located in. The lifetime of the module corresponds to the lifetime of + * the context it was loaded in or until the module is explicitly unloaded. + * + * The CUDA runtime manages its own modules loaded into the primary context. + * If the handle returned by this API refers to a module loaded by the CUDA + * runtime, calling ::cuModuleUnload() on that module will result in undefined + * behavior. 
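+ *
+ * Illustrative sketch (placeholder \p hfunc): looking up the owning module of
+ * a function handle:
+ *
+ * \code
+   CUmodule owner;
+   if (cuFuncGetModule(&owner, hfunc) == CUDA_SUCCESS) {
+       // `owner` stays valid for the lifetime of the module/context;
+       // do not call cuModuleUnload() on it if it is managed by the runtime.
+   }
+ * \endcode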
+ * + * \param hmod - Returned module handle + * \param hfunc - Function to retrieve module for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + */ +CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); + +/** + * \brief Returns the function name for a ::CUfunction handle + * + * Returns in \p **name the function name associated with the function handle \p + * hfunc . The function name is returned as a null-terminated string. The + * returned name is only valid when the function handle is valid. If the module + * is unloaded or reloaded, one must call the API again to get the updated name. + * This API may return a mangled name if the function is not declared as having + * C linkage. If either \p **name or \p hfunc is NULL, + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \param name - The returned name of the function + * \param hfunc - The function handle to retrieve the name for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + */ +CUresult CUDAAPI cuFuncGetName(const char **name, CUfunction hfunc); + +/** + * \brief Returns the offset and size of a kernel parameter in the device-side + * parameter layout + * + * Queries the kernel parameter at \p paramIndex into \p func's list of + * parameters, and returns in \p paramOffset and \p paramSize the offset and + * size, respectively, where the parameter will reside in the device-side + * parameter layout. This information can be used to update kernel node + * parameters from the device via ::cudaGraphKernelNodeSetParam() and + * ::cudaGraphKernelNodeUpdatesApply(). \p paramIndex must be less than the + * number of parameters that \p func takes. \p paramSize can be set to NULL if + * only the parameter offset is desired. + * + * \param func - The function to query + * \param paramIndex - The parameter index to query + * \param paramOffset - Returns the offset into the device-side parameter layout + * at which the parameter resides \param paramSize - Optionally returns the + * size of the parameter in the device-side parameter layout + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuKernelGetParamInfo + */ +CUresult CUDAAPI cuFuncGetParamInfo(CUfunction func, size_t paramIndex, + size_t *paramOffset, size_t *paramSize); + +typedef enum CUfunctionLoadingState_enum { + CU_FUNCTION_LOADING_STATE_UNLOADED = 0, + CU_FUNCTION_LOADING_STATE_LOADED = 1, + CU_FUNCTION_LOADING_STATE_MAX +} CUfunctionLoadingState; + +/** + * \brief Returns if the function is loaded + * + * Returns in \p state the loading state of \p function. + * + * \param state - returned loading state + * \param function - the function to check + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuFuncLoad, + * ::cuModuleEnumerateFunctions + */ +CUresult CUDAAPI cuFuncIsLoaded(CUfunctionLoadingState *state, + CUfunction function); + +/** + * \brief Loads a function + * + * Finalizes function loading for \p function. Calling this API with a + * fully loaded function has no effect. 
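+ *
+ * Illustrative sketch (placeholder \p hfunc): ensuring a function is fully
+ * loaded before querying an attribute that requires a loaded function:
+ *
+ * \code
+   CUfunctionLoadingState state;
+   cuFuncIsLoaded(&state, hfunc);
+   if (state == CU_FUNCTION_LOADING_STATE_UNLOADED) {
+       cuFuncLoad(hfunc);
+   }
+   int maxThreads;
+   cuFuncGetAttribute(&maxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
+                      hfunc);
+ * \endcode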
+ * + * \param function - the function to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuModuleEnumerateFunctions, + * ::cuFuncIsLoaded + */ +CUresult CUDAAPI cuFuncLoad(CUfunction function); + +/** + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel + * + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f + * on a \p gridDimX x \p gridDimY x \p gridDimZ grid of blocks. + * Each block contains \p blockDimX x \p blockDimY x + * \p blockDimZ threads. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * Kernel parameters to \p f can be specified in one of two ways: + * + * 1) Kernel parameters can be specified via \p kernelParams. If \p f + * has N parameters, then \p kernelParams needs to be an array of N + * pointers. Each of \p kernelParams[0] through \p kernelParams[N-1] + * must point to a region of memory from which the actual kernel + * parameter will be copied. The number of kernel parameters and their + * offsets and sizes do not need to be specified as that information is + * retrieved directly from the kernel's image. + * + * 2) Kernel parameters can also be packaged by the application into + * a single buffer that is passed in via the \p extra parameter. + * This places the burden on the application of knowing each kernel + * parameter's size and alignment/padding within the buffer. Here is + * an example of using the \p extra parameter in this manner: + * \code + size_t argBufferSize; + char argBuffer[256]; + + // populate argBuffer and argBufferSize + + void *config[] = { + CU_LAUNCH_PARAM_BUFFER_POINTER, argBuffer, + CU_LAUNCH_PARAM_BUFFER_SIZE, &argBufferSize, + CU_LAUNCH_PARAM_END + }; + status = cuLaunchKernel(f, gx, gy, gz, bx, by, bz, sh, s, NULL, config); + * \endcode + * + * The \p extra parameter exists to allow ::cuLaunchKernel to take + * additional less commonly used arguments. \p extra specifies a list of + * names of extra settings and their corresponding values. Each extra + * setting name is immediately followed by the corresponding value. The + * list must be terminated with either NULL or ::CU_LAUNCH_PARAM_END. + * + * - ::CU_LAUNCH_PARAM_END, which indicates the end of the \p extra + * array; + * - ::CU_LAUNCH_PARAM_BUFFER_POINTER, which specifies that the next + * value in \p extra will be a pointer to a buffer containing all + * the kernel parameters for launching kernel \p f; + * - ::CU_LAUNCH_PARAM_BUFFER_SIZE, which specifies that the next + * value in \p extra will be a pointer to a size_t containing the + * size of the buffer specified with ::CU_LAUNCH_PARAM_BUFFER_POINTER; + * + * The error ::CUDA_ERROR_INVALID_VALUE will be returned if kernel + * parameters are specified with both \p kernelParams and \p extra + * (i.e. both \p kernelParams and \p extra are non-NULL). + * + * Calling ::cuLaunchKernel() invalidates the persistent function state + * set through the following deprecated APIs: + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), + * ::cuParamSetv(). + * + * Note that to use ::cuLaunchKernel(), the kernel \p f must either have + * been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then ::cuLaunchKernel() will + * return ::CUDA_ERROR_INVALID_IMAGE. 
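+ *
+ * For comparison with the \p extra example above, an illustrative sketch
+ * (placeholder names) of the more common \p kernelParams path:
+ *
+ * \code
+   // kernel signature: __global__ void scale(float *out, int n)
+   CUdeviceptr dOut;
+   int n = 1024;
+   void *params[] = { &dOut, &n };
+   cuLaunchKernel(f,
+                  (n + 255) / 256, 1, 1,   // grid dimensions
+                  256, 1, 1,               // block dimensions
+                  0, stream,               // dynamic shared memory, stream
+                  params, NULL);
+ * \endcode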
+ * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch + * \param gridDimX - Width of grid in blocks + * \param gridDimY - Height of grid in blocks + * \param gridDimZ - Depth of grid in blocks + * \param blockDimX - X dimension of each thread block + * \param blockDimY - Y dimension of each thread block + * \param blockDimZ - Z dimension of each thread block + * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes + * \param hStream - Stream identifier + * \param kernelParams - Array of pointers to kernel parameters + * \param extra - Extra options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cudaLaunchKernel, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute + */ +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra); + +/** + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel with + * launch-time configuration + * + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f with the + * specified launch-time configuration \p config. + * + * The ::CUlaunchConfig structure is defined as: + * + * \code + * typedef struct CUlaunchConfig_st { + * unsigned int gridDimX; + * unsigned int gridDimY; + * unsigned int gridDimZ; + * unsigned int blockDimX; + * unsigned int blockDimY; + * unsigned int blockDimZ; + * unsigned int sharedMemBytes; + * CUstream hStream; + * CUlaunchAttribute *attrs; + * unsigned int numAttrs; + * } CUlaunchConfig; + * \endcode + * + * where: + * - ::CUlaunchConfig::gridDimX is the width of the grid in blocks. + * - ::CUlaunchConfig::gridDimY is the height of the grid in blocks. + * - ::CUlaunchConfig::gridDimZ is the depth of the grid in blocks. + * - ::CUlaunchConfig::blockDimX is the X dimension of each thread block. + * - ::CUlaunchConfig::blockDimX is the Y dimension of each thread block. + * - ::CUlaunchConfig::blockDimZ is the Z dimension of each thread block. + * - ::CUlaunchConfig::sharedMemBytes is the dynamic shared-memory size per + * thread block in bytes. + * - ::CUlaunchConfig::hStream is the handle to the stream to perform the launch + * in. The CUDA context associated with this stream must match that associated + * with function f. + * - ::CUlaunchConfig::attrs is an array of ::CUlaunchConfig::numAttrs + * contiguous ::CUlaunchAttribute elements. 
The value of this pointer is not + * considered if ::CUlaunchConfig::numAttrs is zero. However, in that case, it + * is recommended to set the pointer to NULL. + * - ::CUlaunchConfig::numAttrs is the number of attributes populating the + * first ::CUlaunchConfig::numAttrs positions of the ::CUlaunchConfig::attrs + * array. + * + * Launch-time configuration is specified by adding entries to + * ::CUlaunchConfig::attrs. Each entry is an attribute ID and a corresponding + * attribute value. + * + * The ::CUlaunchAttribute structure is defined as: + * \code + * typedef struct CUlaunchAttribute_st { + * CUlaunchAttributeID id; + * CUlaunchAttributeValue value; + * } CUlaunchAttribute; + * \endcode + * where: + * - ::CUlaunchAttribute::id is a unique enum identifying the attribute. + * - ::CUlaunchAttribute::value is a union that hold the attribute value. + * + * An example of using the \p config parameter: + * \code + * CUlaunchAttribute coopAttr = {.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, + * .value = 1}; + * CUlaunchConfig config = {... // set block and grid dimensions + * .attrs = &coopAttr, + * .numAttrs = 1}; + * + * cuLaunchKernelEx(&config, kernel, NULL, NULL); + * \endcode + * + * The ::CUlaunchAttributeID enum is defined as: + * \code + * typedef enum CUlaunchAttributeID_enum { + * CU_LAUNCH_ATTRIBUTE_IGNORE = 0, + * CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, + * CU_LAUNCH_ATTRIBUTE_COOPERATIVE = 2, + * CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3, + * CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = 4, + * CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = 5, + * CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = 6, + * CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = 7, + * CU_LAUNCH_ATTRIBUTE_PRIORITY = 8, + * CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP = 9, + * CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN = 10, + * CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT = 12, + * CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE = 13, + * } CUlaunchAttributeID; + * \endcode + * + * and the corresponding ::CUlaunchAttributeValue union as : + * \code + * typedef union CUlaunchAttributeValue_union { + * CUaccessPolicyWindow accessPolicyWindow; + * int cooperative; + * CUsynchronizationPolicy syncPolicy; + * struct { + * unsigned int x; + * unsigned int y; + * unsigned int z; + * } clusterDim; + * CUclusterSchedulingPolicy clusterSchedulingPolicyPreference; + * int programmaticStreamSerializationAllowed; + * struct { + * CUevent event; + * int flags; + * int triggerAtBlockStart; + * } programmaticEvent; + * int priority; + * CUlaunchMemSyncDomainMap memSyncDomainMap; + * CUlaunchMemSyncDomain memSyncDomain; + * struct { + * CUevent event; + * int flags; + * } launchCompletionEvent; + * struct { + * int deviceUpdatable; + * CUgraphDeviceNode devNode; + * } deviceUpdatableKernelNode; + * } CUlaunchAttributeValue; + * \endcode + * + * Setting ::CU_LAUNCH_ATTRIBUTE_COOPERATIVE to a non-zero value causes the + * kernel launch to be a cooperative launch, with exactly the same usage and + * semantics of ::cuLaunchCooperativeKernel. + * + * Setting ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION to a non-zero + * values causes the kernel to use programmatic means to resolve its stream + * dependency -- enabling the CUDA runtime to opportunistically allow the grid's + * execution to overlap with the previous kernel in the stream, if that kernel + * requests the overlap. + * + * ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT records an event along with the + * kernel launch. 
Event recorded through this launch attribute is guaranteed to + * only trigger after all block in the associated kernel trigger the event. A + * block can trigger the event through PTX launchdep.release or CUDA builtin + * function cudaTriggerProgrammaticLaunchCompletion(). A trigger can also be + * inserted at the beginning of each block's execution if triggerAtBlockStart is + * set to non-0. Note that dependents (including the CPU thread calling + * cuEventSynchronize()) are not guaranteed to observe the release precisely + * when it is released. For example, cuEventSynchronize() may only observe the + * event trigger long after the associated kernel has completed. This recording + * type is primarily meant for establishing programmatic dependency between + * device tasks. The event supplied must not be an interprocess or interop + * event. The event must disable timing (i.e. created with + * ::CU_EVENT_DISABLE_TIMING flag set). + * + * ::CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT records an event along with + * the kernel launch. Nominally, the event is triggered once all blocks of the + * kernel have begun execution. Currently this is a best effort. If a kernel B + * has a launch completion dependency on a kernel A, B may wait until A is + * complete. Alternatively, blocks of B may begin before all blocks of A have + * begun, for example: + * + * - If B can claim execution resources unavailable to A, for example if they + * run on different GPUs. + * - If B is a higher priority than A. + * + * Exercise caution if such an ordering inversion could lead to deadlock. The + * event supplied must not be an interprocess or interop event. The event must + * disable timing (i.e. must be created with the ::CU_EVENT_DISABLE_TIMING flag + * set). + * + * Setting ::CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE to 1 + * on a captured launch causes the resulting kernel node to be device-updatable. + * This attribute is specific to graphs, and passing it to a launch in a + * non-capturing stream results in an error. Passing a value other than 0 or 1 + * is not allowed. + * + * On success, a handle will be returned via + * ::CUlaunchAttributeValue::deviceUpdatableKernelNode::devNode which can be + * passed to the various device-side update functions to update the node's + * kernel parameters from within another kernel. For more information on the + * types of device updates that can be made, as well as the relevant limitations + * thereof, see + * ::cudaGraphKernelNodeUpdatesApply. + * + * Kernel nodes which are device-updatable have additional restrictions compared + * to regular kernel nodes. Firstly, device-updatable nodes cannot be removed + * from their graph via + * ::cuGraphDestroyNode. Additionally, once opted-in to this functionality, a + * node cannot opt out, and any attempt to set the attribute to 0 will result in + * an error. Graphs containing one or more device-updatable node also do not + * allow multiple instantiation. + * + * + * The effect of other attributes is consistent with their effect when set via + * persistent APIs. + * + * See ::cuStreamSetAttribute for + * - ::CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW + * - ::CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY + * + * See ::cuFuncSetAttribute for + * - ::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + * - ::CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE + * + * Kernel parameters to \p f can be specified in the same ways that they can be + * using ::cuLaunchKernel. 
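+ *
+ * A further illustrative sketch (placeholder names \p f, \p stream, \p dOut,
+ * \p n), launching with a 2x1x1 thread block cluster and an explicit launch
+ * priority:
+ *
+ * \code
+ * CUlaunchAttribute attrs[2];
+ * attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+ * attrs[0].value.clusterDim.x = 2;
+ * attrs[0].value.clusterDim.y = 1;
+ * attrs[0].value.clusterDim.z = 1;
+ * attrs[1].id = CU_LAUNCH_ATTRIBUTE_PRIORITY;
+ * attrs[1].value.priority = 0;
+ *
+ * CUlaunchConfig config;
+ * memset(&config, 0, sizeof(config));
+ * config.gridDimX = 16;   config.gridDimY = 1;  config.gridDimZ = 1;
+ * config.blockDimX = 128; config.blockDimY = 1; config.blockDimZ = 1;
+ * config.hStream  = stream;
+ * config.attrs    = attrs;
+ * config.numAttrs = 2;
+ *
+ * void *params[] = { &dOut, &n };
+ * cuLaunchKernelEx(&config, f, params, NULL);
+ * \endcode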
+ * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream + * ::CUlaunchConfig::hStream or the current context in case of NULL stream. + * + * \param config - Config to launch + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch + * \param kernelParams - Array of pointers to kernel parameters + * \param extra - Extra options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cudaLaunchKernel, + * ::cudaLaunchKernelEx, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute + */ +CUresult CUDAAPI cuLaunchKernelEx(const CUlaunchConfig *config, CUfunction f, + void **kernelParams, void **extra); + +/** + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel + * where thread blocks can cooperate and synchronize as they execute + * + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f on a \p + * gridDimX x \p gridDimY x \p gridDimZ grid of blocks. Each block contains \p + * blockDimX x \p blockDimY x \p blockDimZ threads. + * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * The device on which this kernel is invoked must have a non-zero value for + * the device attribute ::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH. + * + * The total number of blocks launched cannot exceed the maximum number of + * blocks per multiprocessor as returned by + * ::cuOccupancyMaxActiveBlocksPerMultiprocessor (or + * ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags) times the number of + * multiprocessors as specified by the device attribute + * ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. + * + * The kernel cannot make use of CUDA dynamic parallelism. + * + * Kernel parameters must be specified via \p kernelParams. If \p f + * has N parameters, then \p kernelParams needs to be an array of N + * pointers. Each of \p kernelParams[0] through \p kernelParams[N-1] + * must point to a region of memory from which the actual kernel + * parameter will be copied. The number of kernel parameters and their + * offsets and sizes do not need to be specified as that information is + * retrieved directly from the kernel's image. 
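+ *
+ * An illustrative sketch (placeholder names \p dev, \p f, \p stream, \p dData,
+ * \p n) of sizing a cooperative launch by the occupancy limit described above:
+ *
+ * \code
+   int supported = 0, smCount = 0, blocksPerSM = 0;
+   cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, dev);
+   cuDeviceGetAttribute(&smCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
+   cuOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSM, f,
+                                               256 /* block size */, 0);
+   if (supported) {
+       void *params[] = { &dData, &n };
+       cuLaunchCooperativeKernel(f,
+                                 blocksPerSM * smCount, 1, 1,  // grid <= limit
+                                 256, 1, 1,
+                                 0, stream, params);
+   }
+ * \endcode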
+ * + * Calling ::cuLaunchCooperativeKernel() sets persistent function state that is + * the same as function state set through ::cuLaunchKernel API + * + * When the kernel \p f is launched via ::cuLaunchCooperativeKernel(), the + * previous block shape, shared size and parameter info associated with \p f is + * overwritten. + * + * Note that to use ::cuLaunchCooperativeKernel(), the kernel \p f must either + * have been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then ::cuLaunchCooperativeKernel() + * will return ::CUDA_ERROR_INVALID_IMAGE. + * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch + * \param gridDimX - Width of grid in blocks + * \param gridDimY - Height of grid in blocks + * \param gridDimZ - Depth of grid in blocks + * \param blockDimX - X dimension of each thread block + * \param blockDimY - Y dimension of each thread block + * \param blockDimZ - Z dimension of each thread block + * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes + * \param hStream - Stream identifier + * \param kernelParams - Array of pointers to kernel parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchCooperativeKernelMultiDevice, + * ::cudaLaunchCooperativeKernel, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute + */ +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams); + +/** + * \brief Launches CUDA functions on multiple devices where thread blocks can + cooperate and synchronize as they execute + * + * \deprecated This function is deprecated as of CUDA 11.3. + * + * Invokes kernels as specified in the \p launchParamsList array where each + element + * of the array specifies all the parameters required to perform a single kernel + launch. + * These kernels can cooperate and synchronize as they execute. The size of the + array is + * specified by \p numDevices. + * + * No two kernels can be launched on the same device. All the devices targeted + by this + * multi-device launch must be identical. All devices must have a non-zero value + for the + * device attribute ::CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH. 
+ * + * All kernels launched must be identical with respect to the compiled code. + Note that + * any __device__, __constant__ or __managed__ variables present in the module + that owns + * the kernel launched on each device, are independently instantiated on every + device. + * It is the application's responsibility to ensure these variables are + initialized and + * used appropriately. + * + * The size of the grids as specified in blocks, the size of the blocks + themselves + * and the amount of shared memory used by each thread block must also match + across + * all launched kernels. + * + * The streams used to launch these kernels must have been created via either + ::cuStreamCreate + * or ::cuStreamCreateWithPriority. The NULL stream or ::CU_STREAM_LEGACY or + ::CU_STREAM_PER_THREAD + * cannot be used. + * + * The total number of blocks launched per kernel cannot exceed the maximum + number of blocks + * per multiprocessor as returned by + ::cuOccupancyMaxActiveBlocksPerMultiprocessor (or + * ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags) times the number of + multiprocessors + * as specified by the device attribute + ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. Since the + * total number of blocks launched per device has to match across all devices, + the maximum + * number of blocks that can be launched per device will be limited by the + device with the + * least number of multiprocessors. + * + * The kernels cannot make use of CUDA dynamic parallelism. + * + * The ::CUDA_LAUNCH_PARAMS structure is defined as: + * \code + typedef struct CUDA_LAUNCH_PARAMS_st + { + CUfunction function; + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + CUstream hStream; + void **kernelParams; + } CUDA_LAUNCH_PARAMS; + * \endcode + * where: + * - ::CUDA_LAUNCH_PARAMS::function specifies the kernel to be launched. All + functions must + * be identical with respect to the compiled code. + * Note that you can also specify context-less kernel ::CUkernel by querying + the handle + * using ::cuLibraryGetKernel() and then casting to ::CUfunction. In this + case, the context to + * launch the kernel on be taken from the specified stream + ::CUDA_LAUNCH_PARAMS::hStream. + * - ::CUDA_LAUNCH_PARAMS::gridDimX is the width of the grid in blocks. This + must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::gridDimY is the height of the grid in blocks. This + must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::gridDimZ is the depth of the grid in blocks. This + must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimX is the X dimension of each thread block. + This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimX is the Y dimension of each thread block. + This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimZ is the Z dimension of each thread block. + This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::sharedMemBytes is the dynamic shared-memory size per + thread block in bytes. + * This must match across all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::hStream is the handle to the stream to perform the + launch in. This cannot + * be the NULL stream or ::CU_STREAM_LEGACY or ::CU_STREAM_PER_THREAD. The + CUDA context associated + * with this stream must match that associated with + ::CUDA_LAUNCH_PARAMS::function. 
+ * - ::CUDA_LAUNCH_PARAMS::kernelParams is an array of pointers to kernel + parameters. If + * ::CUDA_LAUNCH_PARAMS::function has N parameters, then + ::CUDA_LAUNCH_PARAMS::kernelParams + * needs to be an array of N pointers. Each of + ::CUDA_LAUNCH_PARAMS::kernelParams[0] through + * ::CUDA_LAUNCH_PARAMS::kernelParams[N-1] must point to a region of memory + from which the actual + * kernel parameter will be copied. The number of kernel parameters and their + offsets and sizes + * do not need to be specified as that information is retrieved directly from + the kernel's image. + * + * By default, the kernel won't begin execution on any GPU until all prior work + in all the specified + * streams has completed. This behavior can be overridden by specifying the flag + * ::CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_PRE_LAUNCH_SYNC. When this flag is + specified, each kernel + * will only wait for prior work in the stream corresponding to that GPU to + complete before it begins + * execution. + * + * Similarly, by default, any subsequent work pushed in any of the specified + streams will not begin + * execution until the kernels on all GPUs have completed. This behavior can be + overridden by specifying + * the flag ::CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_POST_LAUNCH_SYNC. When + this flag is specified, + * any subsequent work pushed in any of the specified streams will only wait for + the kernel launched + * on the GPU corresponding to that stream to complete before it begins + execution. + * + * Calling ::cuLaunchCooperativeKernelMultiDevice() sets persistent function + state that is + * the same as function state set through ::cuLaunchKernel API when called + individually for each + * element in \p launchParamsList. + * + * When kernels are launched via ::cuLaunchCooperativeKernelMultiDevice(), the + previous + * block shape, shared size and parameter info associated with each + ::CUDA_LAUNCH_PARAMS::function + * in \p launchParamsList is overwritten. + * + * Note that to use ::cuLaunchCooperativeKernelMultiDevice(), the kernels must + either have + * been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then + ::cuLaunchCooperativeKernelMultiDevice() will + * return ::CUDA_ERROR_INVALID_IMAGE. + * + * \param launchParamsList - List of launch parameters, one per device + * \param numDevices - Size of the \p launchParamsList array + * \param flags - Flags to control launch behavior + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchCooperativeKernel, + * ::cudaLaunchCooperativeKernelMultiDevice + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice( + CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, + unsigned int flags); + +/** + * \brief Enqueues a host function call in a stream + * + * Enqueues a host function to run in a stream. 
+ * The function will be called
+ * after currently enqueued work and will block work added after it.
+ *
+ * The host function must not make any CUDA API calls. Attempting to use a
+ * CUDA API may result in ::CUDA_ERROR_NOT_PERMITTED, but this is not required.
+ * The host function must not perform any synchronization that may depend on
+ * outstanding CUDA work not mandated to run earlier. Host functions without a
+ * mandated order (such as in independent streams) execute in undefined order
+ * and may be serialized.
+ *
+ * For the purposes of Unified Memory, execution makes a number of guarantees:
+ *
+ * - The stream is considered idle for the duration of the function's
+ *   execution. Thus, for example, the function may always use memory attached
+ *   to the stream it was enqueued in.
+ * - The start of execution of the function has the same effect as
+ *   synchronizing an event recorded in the same stream immediately prior to
+ *   the function. It thus synchronizes streams which have been "joined"
+ *   prior to the function.
+ * - Adding device work to any stream does not have the effect of making
+ *   the stream active until all preceding host functions and stream callbacks
+ *   have executed. Thus, for example, a function might use global attached
+ *   memory even if work has been added to another stream, if the work has
+ *   been ordered behind the function call with an event.
+ * - Completion of the function does not cause a stream to become
+ *   active except as described above. The stream will remain idle
+ *   if no device work follows the function, and will remain idle across
+ *   consecutive host functions or stream callbacks without device work in
+ *   between. Thus, for example, stream synchronization can be done by
+ *   signaling from a host function at the end of the stream.
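+ *
+ * A minimal, illustrative sketch of enqueueing a host function (the names
+ * \c myHostFn, \c hStream and \c done are placeholders; error checking is
+ * omitted):
+ * \code
+ *   // Host callback: must not call back into the CUDA API.
+ *   void CUDA_CB myHostFn(void *userData) {
+ *       *(int *)userData = 1;                      // e.g. publish a completion flag
+ *   }
+ *
+ *   static int done = 0;
+ *   cuLaunchHostFunc(hStream, myHostFn, &done);    // runs after prior work in hStream
+ * \endcode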
+ * + * Note that, in contrast to ::cuStreamAddCallback, the function will not be + * called in the event of an error in the CUDA context. + * + * \param hStream - Stream to enqueue function call in + * \param fn - The function to call once preceding stream operations are + * complete \param userData - User-specified data to be passed to the function + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cuStreamAttachMemAsync, + * ::cuStreamAddCallback + */ +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, + void *userData); + +/** @} */ /* END CUDA_EXEC */ + +/** + * \defgroup CUDA_EXEC_DEPRECATED Execution Control [DEPRECATED] + * + * ___MANBRIEF___ deprecated execution control functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated execution control functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Sets the block-dimensions for the function + * + * \deprecated + * + * Specifies the \p x, \p y, and \p z dimensions of the thread blocks that are + * created when the kernel given by \p hfunc is launched. + * + * \param hfunc - Kernel to specify dimensions of + * \param x - X dimension + * \param y - Y dimension + * \param z - Z dimension + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetSharedSize, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, + int y, int z); + +/** + * \brief Sets the dynamic shared-memory size for the function + * + * \deprecated + * + * Sets through \p bytes the amount of dynamic shared memory that will be + * available to each thread block when the kernel given by \p hfunc is launched. + * + * \param hfunc - Kernel to specify dynamic shared-memory size for + * \param bytes - Dynamic shared-memory size per thread in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, + unsigned int bytes); + +/** + * \brief Sets the parameter size for the function + * + * \deprecated + * + * Sets through \p numbytes the total size in bytes needed by the function + * parameters of the kernel corresponding to \p hfunc. 
+ * + * \param hfunc - Kernel to set parameter size for + * \param numbytes - Size of parameter list in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, + unsigned int numbytes); + +/** + * \brief Adds an integer parameter to the function's argument list + * + * \deprecated + * + * Sets an integer parameter that will be specified the next time the + * kernel corresponding to \p hfunc will be invoked. \p offset is a byte offset. + * + * \param hfunc - Kernel to add parameter to + * \param offset - Offset to add parameter to argument list + * \param value - Value of parameter + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, + unsigned int value); + +/** + * \brief Adds a floating-point parameter to the function's argument list + * + * \deprecated + * + * Sets a floating-point parameter that will be specified the next time the + * kernel corresponding to \p hfunc will be invoked. \p offset is a byte offset. + * + * \param hfunc - Kernel to add parameter to + * \param offset - Offset to add parameter to argument list + * \param value - Value of parameter + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, + float value); + +/** + * \brief Adds arbitrary data to the function's argument list + * + * \deprecated + * + * Copies an arbitrary amount of data (specified in \p numbytes) from \p ptr + * into the parameter space of the kernel corresponding to \p hfunc. \p offset + * is a byte offset. 
+ * + * \param hfunc - Kernel to add data to + * \param offset - Offset to add data to argument list + * \param ptr - Pointer to arbitrary data + * \param numbytes - Size of data to copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, + void *ptr, + unsigned int numbytes); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a 1 x 1 x 1 grid of blocks. The block + * contains the number of threads specified by a previous call to + * ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * + * \param f - Kernel to launch + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a \p grid_width x \p grid_height grid of + * blocks. Each block contains the number of threads specified by a previous + * call to ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. 
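+ *
+ * For illustration only, the full deprecated launch sequence might look like
+ * the following sketch (\c hfunc and \c dptr are placeholders; error checking
+ * is omitted):
+ * \code
+ *   int n = 1024;
+ *   cuFuncSetBlockShape(hfunc, 256, 1, 1);              // threads per block
+ *   cuFuncSetSharedSize(hfunc, 0);                      // dynamic shared memory
+ *   cuParamSetSize(hfunc, sizeof(CUdeviceptr) + sizeof(int));
+ *   cuParamSetv(hfunc, 0, &dptr, sizeof(CUdeviceptr));  // pointer argument at offset 0
+ *   cuParamSeti(hfunc, sizeof(CUdeviceptr), n);         // int argument after it
+ *   cuLaunchGrid(hfunc, 4, 1);                          // 4 x 1 grid of blocks
+ * \endcode
+ * Offsets must respect each parameter's alignment requirements; the
+ * ::cuLaunchKernel() path avoids this bookkeeping entirely.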
+ * + * \param f - Kernel to launch + * \param grid_width - Width of grid in blocks + * \param grid_height - Height of grid in blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, + int grid_height); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a \p grid_width x \p grid_height grid of + * blocks. Each block contains the number of threads specified by a previous + * call to ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * + * \param f - Kernel to launch + * \param grid_width - Width of grid in blocks + * \param grid_height - Height of grid in blocks + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * + * \note In certain cases where cubins are created with no ABI (i.e., using \p + * ptxas \p --abi-compile \p no), this function may serialize kernel launches. + * The CUDA driver retains asynchronous behavior by growing the per-thread stack + * as needed per launch and not shrinking it afterwards. + * + * \note_null_stream + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, + int grid_width, + int grid_height, + CUstream hStream); + +/** + * \brief Adds a texture-reference to the function's argument list + * + * \deprecated + * + * Makes the CUDA array or linear memory bound to the texture reference + * \p hTexRef available to a device program as a texture. In this version of + * CUDA, the texture-reference must be obtained via ::cuModuleGetTexRef() and + * the \p texunit parameter must be set to ::CU_PARAM_TR_DEFAULT. 
+ * + * \param hfunc - Kernel to add texture-reference to + * \param texunit - Texture unit (must be ::CU_PARAM_TR_DEFAULT) + * \param hTexRef - Texture-reference to add to argument list + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, + int texunit, + CUtexref hTexRef); + +/** + * \brief Sets the shared memory configuration for a device function. + * + * \deprecated + * + * On devices with configurable shared memory banks, this function will + * force all subsequent launches of the specified device function to have + * the given shared memory bank size configuration. On any given launch of the + * function, the shared memory configuration of the device will be temporarily + * changed if needed to suit the function's preferred configuration. Changes in + * shared memory configuration between subsequent launches of functions, + * may introduce a device side synchronization point. + * + * Any per-function setting of shared memory bank size set via + * ::cuFuncSetSharedMemConfig will override the context wide setting set with + * ::cuCtxSetSharedMemConfig. + * + * Changing the shared memory bank size will not increase shared memory usage + * or affect occupancy of kernels, but may have major effects on performance. + * Larger bank sizes will allow for greater potential bandwidth to shared + * memory, but will change what kinds of accesses to shared memory will result + * in bank conflicts. + * + * This function will do nothing on devices with fixed shared memory bank size. + * + * The supported bank configurations are: + * - ::CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE: use the context's shared memory + * configuration when launching this function. + * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: set shared memory bank width to + * be natively four bytes when launching this function. + * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: set shared memory bank width + * to be natively eight bytes when launching this function. + * + * \param hfunc - kernel to be given a shared memory config + * \param config - requested shared memory configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxSetSharedMemConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchKernel, + * ::cudaFuncSetSharedMemConfig + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig config); + +/** @} */ /* END CUDA_EXEC_DEPRECATED */ + +/** + * \defgroup CUDA_GRAPH Graph Management + * + * ___MANBRIEF___ graph management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the graph management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Creates a graph + * + * Creates an empty graph, which is returned via \p phGraph. 
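+ *
+ * An illustrative sketch of the typical graph lifecycle (\c hStream is a
+ * placeholder, the three-argument flags form of ::cuGraphInstantiate is
+ * assumed, and error checking is omitted):
+ * \code
+ *   CUgraph graph;
+ *   CUgraphExec graphExec;
+ *   cuGraphCreate(&graph, 0);
+ *   // ... add kernel / memcpy / memset / host nodes here ...
+ *   cuGraphInstantiate(&graphExec, graph, 0);
+ *   cuGraphLaunch(graphExec, hStream);
+ *   cuStreamSynchronize(hStream);
+ *   cuGraphExecDestroy(graphExec);
+ *   cuGraphDestroy(graph);
+ * \endcode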
+ * + * \param phGraph - Returns newly created graph + * \param flags - Graph creation flags, must be 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + * ::cuGraphInstantiate, + * ::cuGraphDestroy, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); + +/** + * \brief Creates a kernel execution node and adds it to a graph + * + * Creates a new kernel execution node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * The CUDA_KERNEL_NODE_PARAMS structure is defined as: + * + * \code + * typedef struct CUDA_KERNEL_NODE_PARAMS_st { + * CUfunction func; + * unsigned int gridDimX; + * unsigned int gridDimY; + * unsigned int gridDimZ; + * unsigned int blockDimX; + * unsigned int blockDimY; + * unsigned int blockDimZ; + * unsigned int sharedMemBytes; + * void **kernelParams; + * void **extra; + * CUkernel kern; + * CUcontext ctx; + * } CUDA_KERNEL_NODE_PARAMS; + * \endcode + * + * When the graph is launched, the node will invoke kernel \p func on a (\p + * gridDimX x \p gridDimY x \p gridDimZ) grid of blocks. Each block contains + * (\p blockDimX x \p blockDimY x \p blockDimZ) threads. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * Kernel parameters to \p func can be specified in one of two ways: + * + * 1) Kernel parameters can be specified via \p kernelParams. If the kernel has + * N parameters, then \p kernelParams needs to be an array of N pointers. Each + * pointer, from \p kernelParams[0] to \p kernelParams[N-1], points to the + * region of memory from which the actual parameter will be copied. The number + * of kernel parameters and their offsets and sizes do not need to be specified + * as that information is retrieved directly from the kernel's image. + * + * 2) Kernel parameters for non-cooperative kernels can also be packaged by the + * application into a single buffer that is passed in via \p extra. This places + * the burden on the application of knowing each kernel parameter's size and + * alignment/padding within the buffer. The \p extra parameter exists to allow + * this function to take additional less commonly used arguments. \p extra + * specifies a list of names of extra settings and their corresponding values. + * Each extra setting name is immediately followed by the corresponding value. + * The list must be terminated with either NULL or CU_LAUNCH_PARAM_END. 
+ * + * - ::CU_LAUNCH_PARAM_END, which indicates the end of the \p extra + * array; + * - ::CU_LAUNCH_PARAM_BUFFER_POINTER, which specifies that the next + * value in \p extra will be a pointer to a buffer + * containing all the kernel parameters for launching kernel + * \p func; + * - ::CU_LAUNCH_PARAM_BUFFER_SIZE, which specifies that the next + * value in \p extra will be a pointer to a size_t + * containing the size of the buffer specified with + * ::CU_LAUNCH_PARAM_BUFFER_POINTER; + * + * The error ::CUDA_ERROR_INVALID_VALUE will be returned if kernel parameters + * are specified with both \p kernelParams and \p extra (i.e. both \p + * kernelParams and \p extra are non-NULL). + * ::CUDA_ERROR_INVALID_VALUE will be returned if \p extra is used for a + * cooperative kernel. + * + * The \p kernelParams or \p extra array, as well as the argument values it + * points to, are copied during this call. + * + * \note Kernels launched using graphs must not use texture and surface + * references. Reading or writing through any texture or surface reference is + * undefined behavior. This restriction does not apply to texture and surface + * objects. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the GPU execution node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuLaunchKernel, + * ::cuLaunchCooperativeKernel, + * ::cuGraphKernelNodeGetParams, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddKernelNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a kernel node's parameters + * + * Returns the parameters of kernel node \p hNode in \p nodeParams. + * The \p kernelParams or \p extra array returned in \p nodeParams, + * as well as the argument values it points to, are owned by the node. + * This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphKernelNodeSetParams to update the + * parameters of this node. + * + * The params will contain either \p kernelParams or \p extra, + * according to which of these was most recently set on the node. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeSetParams + */ +CUresult CUDAAPI cuGraphKernelNodeGetParams( + CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a kernel node's parameters + * + * Sets the parameters of kernel node \p hNode to \p nodeParams. 
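+ *
+ * An illustrative sketch of populating ::CUDA_KERNEL_NODE_PARAMS, adding a
+ * kernel node and later updating it (\c kernel, \c graph, \c d_data and \c n
+ * are placeholders; error checking is omitted):
+ * \code
+ *   CUDA_KERNEL_NODE_PARAMS kp;
+ *   memset(&kp, 0, sizeof(kp));
+ *   kp.func = kernel;
+ *   kp.gridDimX = 64;   kp.gridDimY = 1;  kp.gridDimZ = 1;
+ *   kp.blockDimX = 256; kp.blockDimY = 1; kp.blockDimZ = 1;
+ *   kp.sharedMemBytes = 0;
+ *   void *args[] = { &d_data, &n };
+ *   kp.kernelParams = args;
+ *
+ *   CUgraphNode kernelNode;
+ *   cuGraphAddKernelNode(&kernelNode, graph, NULL, 0, &kp);
+ *
+ *   kp.gridDimX = 128;                             // e.g. enlarge the grid
+ *   cuGraphKernelNodeSetParams(kernelNode, &kp);   // update before instantiation
+ * \endcode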
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuLaunchKernel, + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeGetParams + */ +CUresult CUDAAPI cuGraphKernelNodeSetParams( + CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a memcpy node and adds it to a graph + * + * Creates a new memcpy node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be + * placed at the root of the graph. \p dependencies may not have any duplicate + * entries. A handle to the new node will be returned in \p phGraphNode. + * + * When the graph is launched, the node will perform the memcpy described by \p + * copyParams. See ::cuMemcpy3D() for a description of the structure and its + * restrictions. + * + * Memcpy nodes have some additional restrictions with regards to managed + * memory, if the system contains at least one device which has a zero value for + * the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. If one or more of the + * operands refer to managed memory, then using the memory type + * ::CU_MEMORYTYPE_UNIFIED is disallowed for those operand(s). The managed + * memory will be treated as residing on either the host or the device, + * depending on which memory type is specified. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param copyParams - Parameters for the memory copy + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuMemcpy3D, + * ::cuGraphMemcpyNodeGetParams, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx); + +/** + * \brief Returns a memcpy node's parameters + * + * Returns the parameters of memcpy node \p hNode in \p nodeParams. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemcpy3D, + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeSetParams + */ +CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode, + CUDA_MEMCPY3D *nodeParams); + +/** + * \brief Sets a memcpy node's parameters + * + * Sets the parameters of memcpy node \p hNode to \p nodeParams. 
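+ *
+ * An illustrative sketch of a host-to-device copy node described with
+ * ::CUDA_MEMCPY3D (\c graph, \c ctx, \c hostBuf, \c d_buf and \c nBytes are
+ * placeholders; error checking is omitted):
+ * \code
+ *   CUDA_MEMCPY3D cp;
+ *   memset(&cp, 0, sizeof(cp));
+ *   cp.srcMemoryType = CU_MEMORYTYPE_HOST;
+ *   cp.srcHost       = hostBuf;
+ *   cp.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+ *   cp.dstDevice     = d_buf;
+ *   cp.WidthInBytes  = nBytes;
+ *   cp.Height        = 1;
+ *   cp.Depth         = 1;
+ *
+ *   CUgraphNode copyNode;
+ *   cuGraphAddMemcpyNode(&copyNode, graph, NULL, 0, &cp, ctx);
+ *   // The same structure can later be edited and re-applied:
+ *   cuGraphMemcpyNodeSetParams(copyNode, &cp);
+ * \endcode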
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuMemcpy3D, + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeGetParams + */ +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, + const CUDA_MEMCPY3D *nodeParams); + +/** + * \brief Creates a memset node and adds it to a graph + * + * Creates a new memset node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be + * placed at the root of the graph. \p dependencies may not have any duplicate + * entries. A handle to the new node will be returned in \p phGraphNode. + * + * The element size must be 1, 2, or 4 bytes. + * When the graph is launched, the node will perform the memset described by \p + * memsetParams. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param memsetParams - Parameters for the memory set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CONTEXT + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuMemsetD2D32, + * ::cuGraphMemsetNodeGetParams, + * ::cuGraphMemsetNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode + */ +CUresult CUDAAPI cuGraphAddMemsetNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, + CUcontext ctx); + +/** + * \brief Returns a memset node's parameters + * + * Returns the parameters of memset node \p hNode in \p nodeParams. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemsetD2D32, + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeSetParams + */ +CUresult CUDAAPI cuGraphMemsetNodeGetParams( + CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a memset node's parameters + * + * Sets the parameters of memset node \p hNode to \p nodeParams. 
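+ *
+ * An illustrative sketch of a memset node (\c graph, \c ctx, \c d_buf and
+ * \c nElems are placeholders; error checking is omitted):
+ * \code
+ *   CUDA_MEMSET_NODE_PARAMS ms;
+ *   memset(&ms, 0, sizeof(ms));
+ *   ms.dst         = d_buf;
+ *   ms.value       = 0;
+ *   ms.elementSize = 4;           // must be 1, 2 or 4 bytes
+ *   ms.width       = nElems;      // elements per row
+ *   ms.height      = 1;           // single row, so the pitch is not significant
+ *
+ *   CUgraphNode memsetNode;
+ *   cuGraphAddMemsetNode(&memsetNode, graph, NULL, 0, &ms, ctx);
+ * \endcode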
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuMemsetD2D32, + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeGetParams + */ +CUresult CUDAAPI cuGraphMemsetNodeSetParams( + CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a host execution node and adds it to a graph + * + * Creates a new CPU execution node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * When the graph is launched, the node will invoke the specified CPU function. + * Host nodes are not supported under MPS with pre-Volta GPUs. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the host node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuLaunchHostFunc, + * ::cuGraphHostNodeGetParams, + * ::cuGraphHostNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a host node's parameters + * + * Returns the parameters of host node \p hNode in \p nodeParams. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchHostFunc, + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeSetParams + */ +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, + CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a host node's parameters + * + * Sets the parameters of host node \p hNode to \p nodeParams. 
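+ *
+ * An illustrative sketch of a host node (\c graph, \c myHostFn and \c done
+ * are placeholders; error checking is omitted):
+ * \code
+ *   CUDA_HOST_NODE_PARAMS hp;
+ *   hp.fn       = myHostFn;       // CUhostFn callback, must not call the CUDA API
+ *   hp.userData = &done;
+ *
+ *   CUgraphNode hostNode;
+ *   cuGraphAddHostNode(&hostNode, graph, NULL, 0, &hp);
+ * \endcode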
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuLaunchHostFunc, + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeGetParams + */ +CUresult CUDAAPI cuGraphHostNodeSetParams( + CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a child graph node and adds it to a graph + * + * Creates a new node which executes an embedded graph, and adds it to \p hGraph + * with \p numDependencies dependencies specified via \p dependencies. It is + * possible for \p numDependencies to be 0, in which case the node will be + * placed at the root of the graph. \p dependencies may not have any duplicate + * entries. A handle to the new node will be returned in \p phGraphNode. + * + * If \p hGraph contains allocation or free nodes, this call will return an + * error. + * + * The node executes an embedded child graph. The child graph is cloned in this + * call. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param childGraph - The graph to clone into this node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUgraph childGraph); + +/** + * \brief Gets a handle to the embedded graph of a child graph node + * + * Gets a handle to the embedded graph in a child graph node. This call + * does not clone the graph. Changes to the graph will be reflected in + * the node, and the node retains ownership of the graph. + * + * Allocation and free nodes cannot be added to the returned graph. + * Attempting to do so will return an error. + * + * \param hNode - Node to get the embedded graph for + * \param phGraph - Location to store a handle to the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphNodeFindInClone + */ +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, + CUgraph *phGraph); + +/** + * \brief Creates an empty node and adds it to a graph + * + * Creates a new node which performs no operation, and adds it to \p hGraph with + * \p numDependencies dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be + * placed at the root of the graph. \p dependencies may not have any duplicate + * entries. A handle to the new node will be returned in \p phGraphNode. + * + * An empty node performs no operation during execution, but can be used for + * transitive ordering. 
For example, a phased execution graph with 2 groups of n + * nodes with a barrier between them can be represented using an empty node and + * 2*n dependency edges, rather than no empty node and n^2 dependency edges. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies); + +/** + * \brief Creates an event record node and adds it to a graph + * + * Creates a new event record node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and event + * specified in \p event. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * Each launch of the graph will record \p event to capture execution of the + * node's dependencies. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphAddEventWaitNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddEventRecordNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUevent event); + +/** + * \brief Returns the event associated with an event record node + * + * Returns the event of event record node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeGetEvent(CUgraphNode hNode, + CUevent *event_out); + +/** + * \brief Sets an event record node's event + * + * Sets the event of event record node \p hNode to \p event. 
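+ *
+ * An illustrative sketch of recording an event from a graph (\c graph and
+ * \c otherEvent are placeholders; error checking is omitted):
+ * \code
+ *   CUevent ev;
+ *   cuEventCreate(&ev, CU_EVENT_DEFAULT);
+ *
+ *   CUgraphNode recordNode;
+ *   cuGraphAddEventRecordNode(&recordNode, graph, NULL, 0, ev);
+ *
+ *   // The node can later be re-bound to a different event:
+ *   cuGraphEventRecordNodeSetEvent(recordNode, otherEvent);
+ * \endcode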
+ * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeSetEvent(CUgraphNode hNode, + CUevent event); + +/** + * \brief Creates an event wait node and adds it to a graph + * + * Creates a new event wait node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and event + * specified in \p event. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * The graph node will wait for all work captured in \p event. See + * ::cuEventRecord() for details on what is captured by an event. \p event may + * be from a different context or device than the launch stream. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphAddEventRecordNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddEventWaitNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, CUevent event); + +/** + * \brief Returns the event associated with an event wait node + * + * Returns the event of event wait node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeGetEvent(CUgraphNode hNode, + CUevent *event_out); + +/** + * \brief Sets an event wait node's event + * + * Sets the event of event wait node \p hNode to \p event. 
+ * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeSetEvent(CUgraphNode hNode, CUevent event); + +/** + * \brief Creates an external semaphore signal node and adds it to a graph + * + * Creates a new external semaphore signal node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * Performs a signal operation on a set of externally allocated semaphore + * objects when the node is launched. The operation(s) will occur after all of + * the node's dependencies have completed. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphExternalSemaphoresSignalNodeGetParams, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresSignalNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore signal node's parameters + * + * Returns the parameters of an external semaphore signal node \p hNode in \p + * params_out. The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed + * or its parameters are modified, and should not be modified directly. Use + * ::cuGraphExternalSemaphoresSignalNodeSetParams to update the parameters of + * this node. 
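+ *
+ * An illustrative sketch of a signal node for one imported semaphore
+ * (\c graph and \c extSem, obtained via ::cuImportExternalSemaphore(), are
+ * placeholders; the fence-style payload is only one possibility and error
+ * checking is omitted):
+ * \code
+ *   CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS sig;
+ *   memset(&sig, 0, sizeof(sig));
+ *   sig.params.fence.value = 1;
+ *
+ *   CUDA_EXT_SEM_SIGNAL_NODE_PARAMS sp;
+ *   memset(&sp, 0, sizeof(sp));
+ *   sp.extSemArray = &extSem;
+ *   sp.paramsArray = &sig;
+ *   sp.numExtSems  = 1;
+ *
+ *   CUgraphNode signalNode;
+ *   cuGraphAddExternalSemaphoresSignalNode(&signalNode, graph, NULL, 0, &sp);
+ * \endcode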
+ * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeGetParams( + CUgraphNode hNode, CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore signal node's parameters + * + * Sets the parameters of an external semaphore signal node \p hNode to \p + * nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeSetParams( + CUgraphNode hNode, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an external semaphore wait node and adds it to a graph + * + * Creates a new external semaphore wait node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * Performs a wait operation on a set of externally allocated semaphore objects + * when the node is launched. The node's dependencies will not be launched + * until the wait operation has completed. 
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphExternalSemaphoresWaitNodeGetParams, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresWaitNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore wait node's parameters + * + * Returns the parameters of an external semaphore wait node \p hNode in \p + * params_out. The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed + * or its parameters are modified, and should not be modified directly. Use + * ::cuGraphExternalSemaphoresSignalNodeSetParams to update the parameters of + * this node. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeGetParams( + CUgraphNode hNode, CUDA_EXT_SEM_WAIT_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore wait node's parameters + * + * Sets the parameters of an external semaphore wait node \p hNode to \p + * nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams( + CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a batch memory operation node and adds it to a graph + * + * Creates a new batch memory operation node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. 
It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * When the node is added, the paramArray inside \p nodeParams is copied and + * therefore it can be freed after the call returns. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. For more + * information, see the Stream Memory Operations section in the programming + * guide(https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html). + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuStreamBatchMemOp, + * ::cuStreamWaitValue32, + * ::cuStreamWriteValue32, + * ::cuStreamWaitValue64, + * ::cuStreamWriteValue64, + * ::cuGraphBatchMemOpNodeGetParams, + * ::cuGraphBatchMemOpNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddBatchMemOpNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a batch mem op node's parameters + * + * Returns the parameters of batch mem op node \p hNode in \p nodeParams_out. + * The \p paramArray returned in \p nodeParams_out is owned by the node. + * This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphBatchMemOpNodeSetParams to update the + * parameters of this node. + * + * \param hNode - Node to get the parameters for + * \param nodeParams_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeSetParams + */ +CUresult CUDAAPI cuGraphBatchMemOpNodeGetParams( + CUgraphNode hNode, CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams_out); + +/** + * \brief Sets a batch mem op node's parameters + * + * Sets the parameters of batch mem op node \p hNode to \p nodeParams. + * + * The paramArray inside \p nodeParams is copied and therefore it can be + * freed after the call returns. 
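+ *
+ * An illustrative sketch of a batch mem op node containing a single 32-bit
+ * wait (\c graph, \c ctx and \c d_flag are placeholders; error checking is
+ * omitted):
+ * \code
+ *   CUstreamBatchMemOpParams op;
+ *   memset(&op, 0, sizeof(op));
+ *   op.operation         = CU_STREAM_MEM_OP_WAIT_VALUE_32;
+ *   op.waitValue.address = d_flag;                  // device address to poll
+ *   op.waitValue.value   = 1;
+ *   op.waitValue.flags   = CU_STREAM_WAIT_VALUE_GEQ;
+ *
+ *   CUDA_BATCH_MEM_OP_NODE_PARAMS bp;
+ *   memset(&bp, 0, sizeof(bp));
+ *   bp.ctx        = ctx;
+ *   bp.count      = 1;
+ *   bp.paramArray = &op;
+ *
+ *   CUgraphNode memOpNode;
+ *   cuGraphAddBatchMemOpNode(&memOpNode, graph, NULL, 0, &bp);
+ * \endcode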
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetParams, + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeGetParams + */ +CUresult CUDAAPI cuGraphBatchMemOpNodeSetParams( + CUgraphNode hNode, const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for a batch mem op node in the given graphExec + * + * Sets the parameters of a batch mem op node in an executable graph \p + * hGraphExec. The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The following fields on operations may be modified on an executable graph: + * + * op.waitValue.address + * op.waitValue.value[64] + * op.waitValue.flags bits corresponding to wait type (i.e. + * CU_STREAM_WAIT_VALUE_FLUSH bit cannot be modified) op.writeValue.address + * op.writeValue.value[64] + * + * Other fields, such as the context, count or type of operations, and other + * types of operations such as membars, may not be modified. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * The paramArray inside \p nodeParams is copied and therefore it can be + * freed after the call returns. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Batch mem op node from the graph from which graphExec was + * instantiated \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeGetParams, + * ::cuGraphBatchMemOpNodeSetParams, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecBatchMemOpNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an allocation node and adds it to a graph + * + * Creates a new allocation node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * When ::cuGraphAddMemAllocNode creates an allocation node, it returns the + * address of the allocation in \p nodeParams.dptr. The allocation's address + * remains fixed across instantiations and launches. 
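+ *
+ * As a minimal sketch (not part of the upstream documentation), assuming an
+ * existing graph \p hGraph and device ordinal 0, an allocation node and a
+ * matching free node could be added like this; error checking is omitted:
+ *
+ * \code
+   CUDA_MEM_ALLOC_NODE_PARAMS allocParams = {0};
+   allocParams.poolProps.allocType     = CU_MEM_ALLOCATION_TYPE_PINNED;
+   allocParams.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+   allocParams.poolProps.location.id   = 0;        // device ordinal
+   allocParams.bytesize                = 1 << 20;  // 1 MiB
+
+   CUgraphNode allocNode, freeNode;
+   cuGraphAddMemAllocNode(&allocNode, hGraph, NULL, 0, &allocParams);
+   CUdeviceptr dptr = allocParams.dptr;  // fixed across instantiations/launches
+
+   // ... add nodes that use dptr, ordered after allocNode ...
+
+   cuGraphAddMemFreeNode(&freeNode, hGraph, &allocNode, 1, dptr);
+ * \endcode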
+ * + * If the allocation is freed in the same graph, by creating a free node using + * ::cuGraphAddMemFreeNode, the allocation can be accessed by nodes ordered + * after the allocation node but before the free node. These allocations cannot + * be freed outside the owning graph, and they can only be freed once in the + * owning graph. + * + * If the allocation is not freed in the same graph, then it can be accessed not + * only by nodes in the graph which are ordered after the allocation node, but + * also by stream operations ordered after the graph's execution but before the + * allocation is freed. + * + * Allocations which are not freed in the same graph can be freed by: + * - passing the allocation to ::cuMemFreeAsync or ::cuMemFree; + * - launching a graph with a free node for that allocation; or + * - specifying ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH during + * instantiation, which makes each launch behave as though it called + * ::cuMemFreeAsync for every unfreed allocation. + * + * It is not possible to free an allocation in both the owning graph and another + * graph. If the allocation is freed in the same graph, a free node cannot be + * added to another graph. If the allocation is freed in another graph, a free + * node can no longer be added to the owning graph. + * + * The following restrictions apply to graphs which contain allocation and/or + * memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemAllocNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUDA_MEM_ALLOC_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a memory alloc node's parameters + * + * Returns the parameters of a memory alloc node \p hNode in \p params_out. + * The \p poolProps and \p accessDescs returned in \p params_out, are owned by + * the node. This memory remains valid until the node is destroyed. The + * returned parameters must not be modified. 
+ * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams + */ +CUresult CUDAAPI cuGraphMemAllocNodeGetParams( + CUgraphNode hNode, CUDA_MEM_ALLOC_NODE_PARAMS *params_out); + +/** + * \brief Creates a memory free node and adds it to a graph + * + * Creates a new memory free node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments + * specified in \p nodeParams. It is possible for \p numDependencies to be 0, in + * which case the node will be placed at the root of the graph. \p dependencies + * may not have any duplicate entries. A handle to the new node will be returned + * in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param dptr - Address of memory to free + * + * ::cuGraphAddMemFreeNode will return ::CUDA_ERROR_INVALID_VALUE if the user + * attempts to free: + * - an allocation twice in the same graph. + * - an address that was not returned by an allocation node. + * - an invalid address. + * + * The following restrictions apply to graphs which contain allocation and/or + * memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemFreeNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUdeviceptr dptr); + +/** + * \brief Returns a memory free node's parameters + * + * Returns the address of a memory free node \p hNode in \p dptr_out. + * + * \param hNode - Node to get the parameters for + * \param dptr_out - Pointer to return the device address + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams + */ +CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, + CUdeviceptr *dptr_out); + +/** + * \brief Free unused memory that was cached on the specified device for use + * with graphs back to the OS. 
+ * + * Blocks which are not in use by a graph that is either currently executing or + * scheduled to execute are freed back to the operating system. + * + * \param device - The device for which cached memory should be freed. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode, + * ::cuDeviceSetGraphMemAttribute, + * ::cuDeviceGetGraphMemAttribute + */ +CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); + +/** + * \brief Query asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT: Amount of memory, in bytes, currently + * associated with graphs + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, + * associated with graphs since the last time it was reset. High watermark can + * only be reset to zero. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT: Amount of memory, in bytes, + * currently allocated for use by the CUDA graphs asynchronous allocator. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, + * currently allocated for use by the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - retrieved value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceSetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, + CUgraphMem_attribute attr, + void *value); + +/** + * \brief Set asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, + * associated with graphs since the last time it was reset. High watermark can + * only be reset to zero. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, + * currently allocated for use by the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - pointer to value to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceGetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceSetGraphMemAttribute(CUdevice device, + CUgraphMem_attribute attr, + void *value); + +/** + * \brief Clones a graph + * + * This function creates a copy of \p originalGraph and returns it in \p + * phGraphClone. All parameters are copied into the cloned graph. The original + * graph may be modified after this call without affecting the clone. + * + * Child graph nodes in the original graph are recursively copied into the + * clone. + * + * \param phGraphClone - Returns newly created cloned graph + * \param originalGraph - Graph to clone + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphNodeFindInClone + */ +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph); + +/** + * \brief Finds a cloned version of a node + * + * This function returns the node in \p hClonedGraph corresponding to \p + * hOriginalNode in the original graph. + * + * \p hClonedGraph must have been cloned from \p hOriginalGraph via + * ::cuGraphClone. 
\p hOriginalNode must have been in \p hOriginalGraph at the + * time of the call to + * ::cuGraphClone, and the corresponding cloned node in \p hClonedGraph must not + * have been removed. The cloned node is then returned via \p phClonedNode. + * + * \param phNode - Returns handle to the cloned node + * \param hOriginalNode - Handle to the original node + * \param hClonedGraph - Cloned graph to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, + CUgraphNode hOriginalNode, + CUgraph hClonedGraph); + +/** + * \brief Returns a node's type + * + * Returns the node type of \p hNode in \p type. + * + * \param hNode - Node to query + * \param type - Pointer to return the node type + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphKernelNodeGetParams, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphHostNodeGetParams, + * ::cuGraphHostNodeSetParams, + * ::cuGraphMemcpyNodeGetParams, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphMemsetNodeGetParams, + * ::cuGraphMemsetNodeSetParams + */ +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type); + +/** + * \brief Returns a graph's nodes + * + * Returns a list of \p hGraph's nodes. \p nodes may be NULL, in which case this + * function will return the number of nodes in \p numNodes. Otherwise, + * \p numNodes entries will be filled in. If \p numNodes is higher than the + * actual number of nodes, the remaining entries in \p nodes will be set to + * NULL, and the number of nodes actually obtained will be returned in \p + * numNodes. + * + * \param hGraph - Graph to query + * \param nodes - Pointer to return the nodes + * \param numNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetType, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, + size_t *numNodes); + +/** + * \brief Returns a graph's root nodes + * + * Returns a list of \p hGraph's root nodes. \p rootNodes may be NULL, in which + * case this function will return the number of root nodes in \p numRootNodes. + * Otherwise, \p numRootNodes entries will be filled in. If \p numRootNodes is + * higher than the actual number of root nodes, the remaining entries in \p + * rootNodes will be set to NULL, and the number of nodes actually obtained will + * be returned in \p numRootNodes. 
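+ *
+ * A typical usage sketch (not part of the upstream documentation) is the
+ * two-call pattern below: query the count first, then fetch the nodes. The
+ * graph handle \p hGraph is assumed to exist; error checking is omitted:
+ *
+ * \code
+   size_t numRootNodes = 0;
+   cuGraphGetRootNodes(hGraph, NULL, &numRootNodes);        // count only
+   CUgraphNode *rootNodes =
+       (CUgraphNode *)malloc(numRootNodes * sizeof(CUgraphNode));
+   cuGraphGetRootNodes(hGraph, rootNodes, &numRootNodes);   // fill the array
+   // ... use rootNodes[0 .. numRootNodes-1] ...
+   free(rootNodes);
+ * \endcode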
+ * + * \param hGraph - Graph to query + * \param rootNodes - Pointer to return the root nodes + * \param numRootNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphGetNodes, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetType, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, + size_t *numRootNodes); + +/** + * \brief Returns a graph's dependency edges + * + * Returns a list of \p hGraph's dependency edges. Edges are returned via + * corresponding indices in \p from and \p to; that is, the node in \p to[i] has + * a dependency on the node in \p from[i]. \p from and \p to may both be NULL, + * in which case this function only returns the number of edges in \p numEdges. + * Otherwise, \p numEdges entries will be filled in. If \p numEdges is higher + * than the actual number of edges, the remaining entries in \p from and \p to + * will be set to NULL, and the number of edges actually returned will be + * written to \p numEdges. + * + * \param hGraph - Graph to get the edges from + * \param from - Location to return edge endpoints + * \param to - Location to return edge endpoints + * \param numEdges - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, size_t *numEdges); + +/** + * \brief Returns a graph's dependency edges (12.3+) + * + * Returns a list of \p hGraph's dependency edges. Edges are returned via + * corresponding indices in \p from, \p to and \p edgeData; that is, the node in + * \p to[i] has a dependency on the node in \p from[i] with data \p edgeData[i]. + * \p from and \p to may both be NULL, in which case this function only returns + * the number of edges in \p numEdges. Otherwise, \p numEdges entries will be + * filled in. If \p numEdges is higher than the actual number of edges, the + * remaining entries in \p from and \p to will be set to NULL, and the number of + * edges actually returned will be written to \p numEdges. \p edgeData may alone + * be NULL, in which case the edges must all have default (zeroed) edge data. + * Attempting a lossy query via NULL \p edgeData will result in + * ::CUDA_ERROR_LOSSY_QUERY. If \p edgeData is non-NULL then \p from and \p to + * must be as well. 
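+ *
+ * As a sketch (not part of the upstream documentation), the same two-call
+ * pattern used with ::cuGraphGetEdges applies here, with an extra array for
+ * the edge data; \p hGraph is assumed to exist and error checking is omitted:
+ *
+ * \code
+   size_t numEdges = 0;
+   cuGraphGetEdges_v2(hGraph, NULL, NULL, NULL, &numEdges);   // count only
+   CUgraphNode *from = (CUgraphNode *)malloc(numEdges * sizeof(CUgraphNode));
+   CUgraphNode *to   = (CUgraphNode *)malloc(numEdges * sizeof(CUgraphNode));
+   CUgraphEdgeData *edgeData =
+       (CUgraphEdgeData *)malloc(numEdges * sizeof(CUgraphEdgeData));
+   cuGraphGetEdges_v2(hGraph, from, to, edgeData, &numEdges);
+   // to[i] depends on from[i] with data edgeData[i]
+   free(from); free(to); free(edgeData);
+ * \endcode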
+ * + * \param hGraph - Graph to get the edges from + * \param from - Location to return edge endpoints + * \param to - Location to return edge endpoints + * \param edgeData - Optional location to return edge data + * \param numEdges - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_LOSSY_QUERY, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetEdges_v2(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, CUgraphEdgeData *edgeData, + size_t *numEdges); + +/** + * \brief Returns a node's dependencies + * + * Returns a list of \p node's dependencies. \p dependencies may be NULL, in + * which case this function will return the number of dependencies in \p + * numDependencies. Otherwise, \p numDependencies entries will be filled in. If + * \p numDependencies is higher than the actual number of dependencies, the + * remaining entries in \p dependencies will be set to NULL, and the number of + * nodes actually obtained will be returned in \p numDependencies. + * + * \param hNode - Node to query + * \param dependencies - Pointer to return the dependencies + * \param numDependencies - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependentNodes, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, + CUgraphNode *dependencies, + size_t *numDependencies); + +/** + * \brief Returns a node's dependencies (12.3+) + * + * Returns a list of \p node's dependencies. \p dependencies may be NULL, in + * which case this function will return the number of dependencies in \p + * numDependencies. Otherwise, \p numDependencies entries will be filled in. If + * \p numDependencies is higher than the actual number of dependencies, the + * remaining entries in \p dependencies will be set to NULL, and the number of + * nodes actually obtained will be returned in \p numDependencies. + * + * Note that if an edge has non-zero (non-default) edge data and \p edgeData is + * NULL, this API will return ::CUDA_ERROR_LOSSY_QUERY. If \p edgeData is + * non-NULL, then \p dependencies must be as well. + * + * \param hNode - Node to query + * \param dependencies - Pointer to return the dependencies + * \param edgeData - Optional array to return edge data for each + * dependency \param numDependencies - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_LOSSY_QUERY, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependentNodes, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependencies_v2(CUgraphNode hNode, + CUgraphNode *dependencies, + CUgraphEdgeData *edgeData, + size_t *numDependencies); + +/** + * \brief Returns a node's dependent nodes + * + * Returns a list of \p node's dependent nodes. 
\p dependentNodes may be NULL, + * in which case this function will return the number of dependent nodes in \p + * numDependentNodes. Otherwise, \p numDependentNodes entries will be filled in. + * If \p numDependentNodes is higher than the actual number of dependent nodes, + * the remaining entries in \p dependentNodes will be set to NULL, and the + * number of nodes actually obtained will be returned in \p numDependentNodes. + * + * \param hNode - Node to query + * \param dependentNodes - Pointer to return the dependent nodes + * \param numDependentNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependencies, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, + CUgraphNode *dependentNodes, + size_t *numDependentNodes); + +/** + * \brief Returns a node's dependent nodes (12.3+) + * + * Returns a list of \p node's dependent nodes. \p dependentNodes may be NULL, + * in which case this function will return the number of dependent nodes in \p + * numDependentNodes. Otherwise, \p numDependentNodes entries will be filled in. + * If \p numDependentNodes is higher than the actual number of dependent nodes, + * the remaining entries in \p dependentNodes will be set to NULL, and the + * number of nodes actually obtained will be returned in \p numDependentNodes. + * + * Note that if an edge has non-zero (non-default) edge data and \p edgeData is + * NULL, this API will return ::CUDA_ERROR_LOSSY_QUERY. If \p edgeData is + * non-NULL, then \p dependentNodes must be as well. + * + * \param hNode - Node to query + * \param dependentNodes - Pointer to return the dependent nodes + * \param edgeData - Optional pointer to return edge data for dependent + * nodes \param numDependentNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_LOSSY_QUERY, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependencies, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependentNodes_v2(CUgraphNode hNode, + CUgraphNode *dependentNodes, + CUgraphEdgeData *edgeData, + size_t *numDependentNodes); + +/** + * \brief Adds dependency edges to a graph + * + * The number of dependencies to be added is defined by \p numDependencies + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying an existing dependency will return an error. 
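+ *
+ * As a minimal sketch (not part of the upstream documentation), assuming two
+ * nodes \p nodeA and \p nodeB that were previously added to \p hGraph, a
+ * single edge making \p nodeB depend on \p nodeA could be added like this:
+ *
+ * \code
+   CUgraphNode from[1] = { nodeA };   // provides the dependency
+   CUgraphNode to[1]   = { nodeB };   // depends on from[0]
+   cuGraphAddDependencies(hGraph, from, to, 1);
+ * \endcode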
+ * + * \param hGraph - Graph to which dependencies are added + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param numDependencies - Number of dependencies to be added + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphRemoveDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies); + +/** + * \brief Adds dependency edges to a graph (12.3+) + * + * The number of dependencies to be added is defined by \p numDependencies + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying an existing dependency will return an error. + * + * \param hGraph - Graph to which dependencies are added + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param edgeData - Optional array of edge data. If NULL, default (zeroed) edge + * data is assumed. \param numDependencies - Number of dependencies to be added + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphRemoveDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphAddDependencies_v2(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + const CUgraphEdgeData *edgeData, + size_t numDependencies); + +/** + * \brief Removes dependency edges from a graph + * + * The number of \p dependencies to be removed is defined by \p numDependencies. + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying a non-existing dependency will return an error. + * + * Dependencies cannot be removed from graphs which contain allocation or free + * nodes. Any attempt to do so will return an error. + * + * \param hGraph - Graph from which to remove dependencies + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param numDependencies - Number of dependencies to be removed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies); + +/** + * \brief Removes dependency edges from a graph (12.3+) + * + * The number of \p dependencies to be removed is defined by \p numDependencies. + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying an edge that does not exist in the graph, with data matching + * \p edgeData, results in an error. 
\p edgeData is nullable, which is + * equivalent to passing default (zeroed) data for each edge. + * + * Dependencies cannot be removed from graphs which contain allocation or free + * nodes. Any attempt to do so will return an error. + * + * \param hGraph - Graph from which to remove dependencies + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param edgeData - Optional array of edge data. If NULL, edge data is assumed + * to be default (zeroed). \param numDependencies - Number of dependencies to be + * removed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphRemoveDependencies_v2(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + const CUgraphEdgeData *edgeData, + size_t numDependencies); + +/** + * \brief Remove a node from the graph + * + * Removes \p hNode from its graph. This operation also severs any dependencies + * of other nodes on \p hNode and vice versa. + * + * Nodes which belong to a graph which contains allocation or free nodes cannot + * be destroyed. Any attempt to do so will return an error. + * + * \param hNode - Node to remove + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); + +/** + * \brief Creates an executable graph from a graph + * + * Instantiates \p hGraph as an executable graph. The graph is validated for any + * structural constraints or intra-node constraints which were not previously + * validated. If instantiation is successful, a handle to the instantiated graph + * is returned in \p phGraphExec. + * + * The \p flags parameter controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a + * graph containing memory allocation nodes to automatically free any + * unfreed memory allocations before the graph is relaunched. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH, which configures the graph for + * launch from the device. If this flag is passed, the executable graph handle + * returned can be used to launch the graph from both the host and device. This + * flag can only be used on platforms which support unified addressing. This + * flag cannot be used in conjunction with + * ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, which causes the graph + * to use the priorities from the per-node attributes rather than the priority + * of the launch stream during execution. Note that priorities are only + * available on kernel nodes, and are copied from stream priority during stream + * capture. + * + * If \p hGraph contains any allocation or free nodes, there can be at most one + * executable graph in existence for that graph at a time. An attempt to + * instantiate a second executable graph before destroying the first with + * ::cuGraphExecDestroy will result in an error. The same also applies if \p + * hGraph contains any device-updatable kernel nodes. 
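+ *
+ * As a minimal sketch (not part of the upstream documentation), assuming a
+ * previously constructed graph \p hGraph that contains allocation nodes and a
+ * stream \p hStream, instantiation and launch could look like this; error
+ * checking is abbreviated:
+ *
+ * \code
+   CUgraphExec hGraphExec;
+   if (cuGraphInstantiate(&hGraphExec, hGraph,
+                          CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH) ==
+       CUDA_SUCCESS) {
+       cuGraphLaunch(hGraphExec, hStream);
+   }
+ * \endcode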
+ * + * If \p hGraph contains kernels which call device-side cudaGraphLaunch() from + * multiple contexts, this will result in an error. + * + * Graphs instantiated for launch on the device have additional restrictions + * which do not apply to host graphs: + * + * - The graph's nodes must reside on a single context. + * - The graph can only contain kernel nodes, memcpy nodes, memset nodes, and + * child graph nodes. + * - The graph cannot be empty and must contain at least one kernel, memcpy, or + * memset node. Operation-specific restrictions are outlined below. + * - Kernel nodes: + * - Use of CUDA Dynamic Parallelism is not permitted. + * - Cooperative launches are permitted as long as MPS is not in use. + * - Memcpy nodes: + * - Only copies involving device memory and/or pinned device-mapped host + * memory are permitted. + * - Copies involving CUDA arrays are not permitted. + * - Both operands must be accessible from the current context, and the + * current context must match the context of other nodes in the graph. + * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param flags - Flags to control instantiation. See + * ::CUgraphInstantiate_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphCreate, + * ::cuGraphUpload, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + unsigned long long flags); + +/** + * \brief Creates an executable graph from a graph + * + * Instantiates \p hGraph as an executable graph according to the \p + instantiateParams structure. + * The graph is validated for any structural constraints or intra-node + constraints + * which were not previously validated. If instantiation is successful, a handle + to + * the instantiated graph is returned in \p phGraphExec. + * + * \p instantiateParams controls the behavior of instantiation and subsequent + * graph launches, as well as returning more detailed information in the event + of an error. + * ::CUDA_GRAPH_INSTANTIATE_PARAMS is defined as: + * + * \code + typedef struct { + cuuint64_t flags; + CUstream hUploadStream; + CUgraphNode hErrNode_out; + CUgraphInstantiateResult result_out; + } CUDA_GRAPH_INSTANTIATE_PARAMS; + * \endcode + * + * The \p flags field controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a + * graph containing memory allocation nodes to automatically free any + * unfreed memory allocations before the graph is relaunched. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD, which will perform an upload of the + graph + * into \p hUploadStream once the graph has been instantiated. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH, which configures the graph for + launch + * from the device. If this flag is passed, the executable graph handle returned + can be + * used to launch the graph from both the host and device. This flag can only be + used + * on platforms which support unified addressing. This flag cannot be used in + * conjunction with ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH. 
+ * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, which causes the graph + * to use the priorities from the per-node attributes rather than the priority + * of the launch stream during execution. Note that priorities are only + available + * on kernel nodes, and are copied from stream priority during stream capture. + * + * If \p hGraph contains any allocation or free nodes, there can be at most one + * executable graph in existence for that graph at a time. An attempt to + instantiate a + * second executable graph before destroying the first with ::cuGraphExecDestroy + will + * result in an error. + * The same also applies if \p hGraph contains any device-updatable kernel + nodes. + * + * If \p hGraph contains kernels which call device-side cudaGraphLaunch() from + multiple + * contexts, this will result in an error. + * + * Graphs instantiated for launch on the device have additional restrictions + which do not + * apply to host graphs: + * + * - The graph's nodes must reside on a single context. + * - The graph can only contain kernel nodes, memcpy nodes, memset nodes, and + child graph nodes. + * - The graph cannot be empty and must contain at least one kernel, memcpy, or + memset node. + * Operation-specific restrictions are outlined below. + * - Kernel nodes: + * - Use of CUDA Dynamic Parallelism is not permitted. + * - Cooperative launches are permitted as long as MPS is not in use. + * - Memcpy nodes: + * - Only copies involving device memory and/or pinned device-mapped host + memory are permitted. + * - Copies involving CUDA arrays are not permitted. + * - Both operands must be accessible from the current context, and the + current context must + * match the context of other nodes in the graph. + * + * In the event of an error, the \p result_out and \p hErrNode_out fields will + contain more + * information about the nature of the error. Possible error reporting includes: + * + * - ::CUDA_GRAPH_INSTANTIATE_ERROR, if passed an invalid value or if an + unexpected error occurred + * which is described by the return value of the function. \p hErrNode_out + will be set to NULL. + * - ::CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE, if the graph structure is + invalid. \p hErrNode_out + * will be set to one of the offending nodes. + * - ::CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED, if the graph is + instantiated for device + * launch but contains a node of an unsupported node type, or a node which + performs unsupported + * operations, such as use of CUDA dynamic parallelism within a kernel node. + \p hErrNode_out will + * be set to this node. + * - ::CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED, if the graph is + instantiated for device + * launch but a node’s context differs from that of another node. This error + can also be returned + * if a graph is not instantiated for device launch and it contains kernels + which call device-side + * cudaGraphLaunch() from multiple contexts. \p hErrNode_out will be set to + this node. + * + * If instantiation is successful, \p result_out will be set to + ::CUDA_GRAPH_INSTANTIATE_SUCCESS, + * and \p hErrNode_out will be set to NULL. 
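+ *
+ * A minimal sketch (not part of the upstream documentation), assuming an
+ * existing graph \p hGraph and stream \p hStream, with error handling reduced
+ * to inspecting the returned diagnostics:
+ *
+ * \code
+   CUDA_GRAPH_INSTANTIATE_PARAMS instantiateParams = {0};
+   instantiateParams.flags         = CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD;
+   instantiateParams.hUploadStream = hStream;
+
+   CUgraphExec hGraphExec;
+   if (cuGraphInstantiateWithParams(&hGraphExec, hGraph,
+                                    &instantiateParams) != CUDA_SUCCESS) {
+       // instantiateParams.result_out and instantiateParams.hErrNode_out
+       // describe which node (if any) caused the failure
+   }
+ * \endcode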
+ * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param instantiateParams - Instantiation parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphInstantiate, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI +cuGraphInstantiateWithParams(CUgraphExec *phGraphExec, CUgraph hGraph, + CUDA_GRAPH_INSTANTIATE_PARAMS *instantiateParams); + +/** + * \brief Query the instantiation flags of an executable graph + * + * Returns the flags that were passed to instantiation for the given executable + * graph. + * ::CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD will not be returned by this API as it + * does not affect the resulting executable graph. + * + * \param hGraphExec - The executable graph to query + * \param flags - Returns the instantiation flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphInstantiateWithParams + */ +CUresult CUDAAPI cuGraphExecGetFlags(CUgraphExec hGraphExec, cuuint64_t *flags); + +/** + * \brief Sets the parameters for a kernel node in the given graphExec + * + * Sets the parameters of a kernel node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. All \p + * nodeParams fields may change, but the following restrictions apply to \p func + * updates: + * + * - The owning context of the function cannot change. + * - A node whose function originally did not use CUDA dynamic parallelism + * cannot be updated to a function which uses CDP + * - A node whose function originally did not make device-side update calls + * cannot be updated to a function which makes device-side update calls. + * - If \p hGraphExec was not instantiated for device launch, a node whose + * function originally did not use device-side cudaGraphLaunch() cannot be + * updated to a function which uses device-side cudaGraphLaunch() unless the + * node resides on the same context as nodes which contained such calls at + * instantiate-time. If no such calls were present at instantiation, these + * updates cannot be performed at all. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * If \p hNode is a device-updatable kernel node, the next upload/launch of \p + * hGraphExec will overwrite any previous device-side updates. Additionally, + * applying host updates to a device-updatable kernel node while it is being + * updated from the device will result in undefined behavior. 
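+ *
+ * As an illustrative sketch (not part of the upstream documentation), assume
+ * \p updatedParams is the ::CUDA_KERNEL_NODE_PARAMS structure originally used
+ * when \p hNode was added, and that \p newArg and \p newGridDimX are the
+ * updated values; error checking is omitted:
+ *
+ * \code
+   void *updatedArgs[] = { &newArg };
+   updatedParams.kernelParams = updatedArgs;   // swap the kernel arguments
+   updatedParams.gridDimX     = newGridDimX;   // launch geometry may change
+   cuGraphExecKernelNodeSetParams(hGraphExec, hNode, &updatedParams);
+ * \endcode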
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - kernel node from the graph from which graphExec was + * instantiated \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI +cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for a memcpy node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode + * had contained \p copyParams at instantiation. hNode must remain in the graph + * which was used to instantiate \p hGraphExec. Changed edges to and from hNode + * are ignored. + * + * The source and destination memory in \p copyParams must be allocated from the + * same contexts as the original source and destination memory. Both the + * instantiation-time memory operands and the memory operands in \p copyParams + * must be 1-dimensional. Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * hNode is also not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operands' mappings changed or + * either the original or new memory operands are multidimensional. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memcpy node from the graph which was used to instantiate + * graphExec \param copyParams - The updated parameters to set \param ctx - + * Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemcpyNodeSetParams(CUgraphExec hGraphExec, + CUgraphNode hNode, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx); + +/** + * \brief Sets the parameters for a memset node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode + * had contained \p memsetParams at instantiation. hNode must remain in the + * graph which was used to instantiate \p hGraphExec. Changed edges to and from + * hNode are ignored. + * + * The destination memory in \p memsetParams must be allocated from the same + * contexts as the original destination memory. 
Both the instantiation-time + * memory operand and the memory operand in \p memsetParams must be + * 1-dimensional. Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * hNode is also not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operand's mappings changed or + * either the original or new memory operand are multidimensional. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memset node from the graph which was used to + * instantiate graphExec \param memsetParams - The updated parameters to set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemsetNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); + +/** + * \brief Sets the parameters for a host node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode + * had contained \p nodeParams at instantiation. hNode must remain in the graph + * which was used to instantiate \p hGraphExec. Changed edges to and from hNode + * are ignored. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Host node from the graph which was used to instantiate + * graphExec \param nodeParams - The updated parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI +cuGraphExecHostNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Updates node parameters in the child graph node in the given + * graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though the nodes + * contained in \p hNode's graph had the parameters contained in \p childGraph's + * nodes at instantiation. \p hNode must remain in the graph which was used to + * instantiate \p hGraphExec. Changed edges to and from \p hNode are ignored. 
+ * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * The topology of \p childGraph, as well as the node insertion order, must + * match that of the graph contained in \p hNode. See ::cuGraphExecUpdate() for + * a list of restrictions on what can be updated in an instantiated graph. The + * update is recursive, so child graph nodes contained within the top level + * child graph will also be updated. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Host node from the graph which was used to instantiate + * graphExec \param childGraph - The graph supplying the updated parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddChildGraphNode, + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecChildGraphNodeSetParams(CUgraphExec hGraphExec, + CUgraphNode hNode, + CUgraph childGraph); + +/** + * \brief Sets the event for an event record node in the given graphExec + * + * Sets the event of an event record node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event record node from the graph from which graphExec was + * instantiated \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventRecordNodeSetEvent(CUgraphExec hGraphExec, + CUgraphNode hNode, + CUevent event); + +/** + * \brief Sets the event for an event wait node in the given graphExec + * + * Sets the event of an event wait node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The modifications only affect future launches of \p hGraphExec. 
Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event wait node from the graph from which graphExec was + * instantiated \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventWaitNodeSetEvent(CUgraphExec hGraphExec, + CUgraphNode hNode, + CUevent event); + +/** + * \brief Sets the parameters for an external semaphore signal node in the given + * graphExec + * + * Sets the parameters of an external semaphore signal node in an executable + * graph \p hGraphExec. The node is identified by the corresponding node \p + * hNode in the non-executable graph, from which the executable graph was + * instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore signal node from the graph from which graphExec + * was instantiated \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresSignalNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for an external semaphore wait node in the given + * graphExec + * + * Sets the parameters of an external semaphore wait node in an executable graph + * \p hGraphExec. The node is identified by the corresponding node \p hNode in + * the non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. 
+ * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore wait node from the graph from which graphExec + * was instantiated \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExecNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresWaitNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Enables or disables the specified node in the given graphExec + * + * Sets \p hNode to be either enabled or disabled. Disabled nodes are + * functionally equivalent to empty nodes until they are re-enabled. Existing + * node parameters are not affected by disabling/enabling the node. + * + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * If \p hNode is a device-updatable kernel node, the next upload/launch of \p + * hGraphExec will overwrite any previous device-side updates. Additionally, + * applying host updates to a device-updatable kernel node while it is being + * updated from the device will result in undefined behavior. + * + * \note Currently only kernel, memset and memcpy nodes are supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Node from the graph from which graphExec was instantiated + * \param isEnabled - Node is enabled if != 0, otherwise the node is disabled + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetEnabled, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + * ::cuGraphLaunch + */ +CUresult CUDAAPI cuGraphNodeSetEnabled(CUgraphExec hGraphExec, + CUgraphNode hNode, + unsigned int isEnabled); + +/** + * \brief Query whether a node in the given graphExec is enabled + * + * Sets isEnabled to 1 if \p hNode is enabled, or 0 if \p hNode is disabled. + * + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * \note Currently only kernel, memset and memcpy nodes are supported. + * \note This function will not reflect device-side updates for device-updatable + * kernel nodes. 
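+ *
+ * As a minimal sketch (not part of the upstream documentation), assuming an
+ * instantiated graph \p hGraphExec, a node \p hNode from the source graph,
+ * and a stream \p hStream; error checking is omitted:
+ *
+ * \code
+   unsigned int isEnabled = 0;
+   cuGraphNodeGetEnabled(hGraphExec, hNode, &isEnabled);
+   if (isEnabled) {
+       cuGraphNodeSetEnabled(hGraphExec, hNode, 0);  // behaves like an empty node
+   }
+   cuGraphLaunch(hGraphExec, hStream);               // launches without hNode's work
+   cuGraphNodeSetEnabled(hGraphExec, hNode, 1);      // re-enable for later launches
+ * \endcode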
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Node from the graph from which graphExec was instantiated + * \param isEnabled - Location to return the enabled status of the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetEnabled, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + * ::cuGraphLaunch + */ +CUresult CUDAAPI cuGraphNodeGetEnabled(CUgraphExec hGraphExec, + CUgraphNode hNode, + unsigned int *isEnabled); + +/** + * \brief Uploads an executable graph in a stream + * + * Uploads \p hGraphExec to the device in \p hStream without executing it. + * Uploads of the same \p hGraphExec will be serialized. Each upload is ordered + * behind both any previous work in \p hStream and any previous launches of \p + * hGraphExec. Uses memory cached by \p stream to back the allocations owned by + * \p hGraphExec. + * + * \param hGraphExec - Executable graph to upload + * \param hStream - Stream in which to upload the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraphExec, CUstream hStream); + +/** + * \brief Launches an executable graph in a stream + * + * Executes \p hGraphExec in \p hStream. Only one instance of \p hGraphExec may + * be executing at a time. Each launch is ordered behind both any previous work + * in \p hStream and any previous launches of \p hGraphExec. To execute a graph + * concurrently, it must be instantiated multiple times into multiple executable + * graphs. + * + * If any allocations created by \p hGraphExec remain unfreed (from a previous + * launch) and \p hGraphExec was not instantiated with + * ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, the launch will fail with + * ::CUDA_ERROR_INVALID_VALUE. + * + * \param hGraphExec - Executable graph to launch + * \param hStream - Stream in which to launch the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphUpload, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream); + +/** + * \brief Destroys an executable graph + * + * Destroys the executable graph specified by \p hGraphExec, as well + * as all of its executable nodes. If the executable graph is + * in-flight, it will not be terminated, but rather freed + * asynchronously on completion. + * + * \param hGraphExec - Executable graph to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphUpload, + * ::cuGraphLaunch + */ +CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec); + +/** + * \brief Destroys a graph + * + * Destroys the graph specified by \p hGraph, as well as all of its nodes. 
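+ *
+ * For illustration, a minimal teardown sketch (assumes \p hGraph was built and
+ * instantiated into \c hGraphExec, and \c hStream is an existing stream; error
+ * checking omitted):
+ *
+ * \code
+ * cuGraphUpload(hGraphExec, hStream);   // optional: pre-load on the device
+ * cuGraphLaunch(hGraphExec, hStream);
+ * cuStreamSynchronize(hStream);
+ * cuGraphExecDestroy(hGraphExec);       // destroy the executable graph first
+ * cuGraphDestroy(hGraph);               // then destroy the graph it came from
+ * \endcode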
+ * + * \param hGraph - Graph to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); + +/** + * \brief Check whether an executable graph can be updated with a graph and + * perform the update if possible + * + * Updates the node parameters in the instantiated graph specified by \p + * hGraphExec with the node parameters in a topologically identical graph + * specified by \p hGraph. + * + * Limitations: + * + * - Kernel nodes: + * - The owning context of the function cannot change. + * - A node whose function originally did not use CUDA dynamic parallelism + * cannot be updated to a function which uses CDP. + * - A node whose function originally did not make device-side update calls + * cannot be updated to a function which makes device-side update calls. + * - A cooperative node cannot be updated to a non-cooperative node, and + * vice-versa. + * - If the graph was instantiated with + * CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, the priority attribute cannot + * change. Equality is checked on the originally requested priority values, + * before they are clamped to the device's supported range. + * - If \p hGraphExec was not instantiated for device launch, a node whose + * function originally did not use device-side cudaGraphLaunch() cannot be + * updated to a function which uses device-side cudaGraphLaunch() unless the + * node resides on the same context as nodes which contained such calls at + * instantiate-time. If no such calls were present at instantiation, these + * updates cannot be performed at all. + * - Neither \p hGraph nor \p hGraphExec may contain device-updatable kernel + * nodes. + * - Memset and memcpy nodes: + * - The CUDA device(s) to which the operand(s) was allocated/mapped cannot + * change. + * - The source/destination memory must be allocated from the same contexts as + * the original source/destination memory. + * - Only 1D memsets can be changed. + * - Additional memcpy node restrictions: + * - Changing either the source or destination memory type(i.e. + * CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_ARRAY, etc.) is not supported. + * - External semaphore wait nodes and record nodes: + * - Changing the number of semaphores is not supported. + * - Conditional nodes: + * - Changing node parameters is not supported. + * - Changing parameters of nodes within the conditional body graph is subject + * to the rules above. + * - Conditional handle flags and default values are updated as part of the + * graph update. + * + * Note: The API may add further restrictions in future releases. The return + * code should always be checked. + * + * cuGraphExecUpdate sets the result member of \p resultInfo to + * CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED under the following conditions: + * - The count of nodes directly in \p hGraphExec and \p hGraph differ, in which + * case resultInfo->errorNode is set to NULL. + * - \p hGraph has more exit nodes than \p hGraph, in which case + * resultInfo->errorNode is set to one of the exit nodes in hGraph. + * - A node in \p hGraph has a different number of dependencies than the node + * from \p hGraphExec it is paired with, in which case resultInfo->errorNode is + * set to the node from \p hGraph. 
+ * - A node in \p hGraph has a dependency that does not match with the + * corresponding dependency of the paired node from \p hGraphExec. + * resultInfo->errorNode will be set to the node from \p hGraph. + * resultInfo->errorFromNode will be set to the mismatched dependency. The + * dependencies are paired based on edge order and a dependency does not match + * when the nodes are already paired based on other edges examined in the graph. + * + * cuGraphExecUpdate sets the result member of \p resultInfo to: + * - CU_GRAPH_EXEC_UPDATE_ERROR if passed an invalid value. + * - CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED if the graph topology changed + * - CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED if the type of a node changed, + * in which case \p hErrorNode_out is set to the node from \p hGraph. + * - CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE if the function + * changed in an unsupported way(see note above), in which case \p + * hErrorNode_out is set to the node from \p hGraph + * - CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED if any parameters to a node + * changed in a way that is not supported, in which case \p hErrorNode_out is + * set to the node from \p hGraph. + * - CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED if any attributes of a node + * changed in a way that is not supported, in which case \p hErrorNode_out is + * set to the node from \p hGraph. + * - CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED if something about a node is + * unsupported, like the node's type or configuration, in which case \p + * hErrorNode_out is set to the node from \p hGraph + * + * If the update fails for a reason not listed above, the result member of \p + * resultInfo will be set to CU_GRAPH_EXEC_UPDATE_ERROR. If the update succeeds, + * the result member will be set to CU_GRAPH_EXEC_UPDATE_SUCCESS. + * + * cuGraphExecUpdate returns CUDA_SUCCESS when the updated was performed + * successfully. It returns CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE if the graph + * update was not performed because it included changes which violated + * constraints specific to instantiated graph update. + * + * \param hGraphExec The instantiated graph to be updated + * \param hGraph The graph containing the updated parameters + * \param resultInfo the error info structure + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, + CUgraphExecUpdateResultInfo *resultInfo); + +/** + * \brief Copies attributes from source node to destination node. + * + * Copies attributes from source node \p src to destination node \p dst. + * Both node must have the same context. + * + * \param[out] dst Destination node + * \param[in] src Source node + * For list of attributes see ::CUkernelNodeAttrID + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuGraphKernelNodeCopyAttributes(CUgraphNode dst, + CUgraphNode src); + +/** + * \brief Queries node attribute. + * + * Queries attribute \p attr from node \p hNode and stores it in corresponding + * member of \p value_out. 
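+ *
+ * For illustration, a minimal sketch (assumes \p hNode is a kernel node; error
+ * checking omitted):
+ *
+ * \code
+ * CUkernelNodeAttrValue value;
+ * cuGraphKernelNodeGetAttribute(hNode,
+ *                               CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW,
+ *                               &value);
+ * // value.accessPolicyWindow now holds the node's access policy window.
+ * \endcode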
+ * + * \param[in] hNode + * \param[in] attr + * \param[out] value_out + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI +cuGraphKernelNodeGetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + CUkernelNodeAttrValue *value_out); + +/** + * \brief Sets node attribute. + * + * Sets attribute \p attr on node \p hNode from corresponding attribute of + * \p value. + * + * \param[out] hNode + * \param[in] attr + * \param[out] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI +cuGraphKernelNodeSetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + const CUkernelNodeAttrValue *value); + +/** + * \brief Write a DOT file describing graph structure + * + * Using the provided \p hGraph, write to \p path a DOT formatted description of + * the graph. By default this includes the graph topology, node types, node id, + * kernel names and memcpy direction. \p flags can be specified to write more + * detailed information about each node type such as parameter values, kernel + * attributes, node and function handles. + * + * \param hGraph - The graph to create a DOT file from + * \param path - The path to write the DOT file to + * \param flags - Flags from CUgraphDebugDot_flags for specifying which + * additional node information to write + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OPERATING_SYSTEM + */ +CUresult CUDAAPI cuGraphDebugDotPrint(CUgraph hGraph, const char *path, + unsigned int flags); + +/** + * \brief Create a user object + * + * Create a user object with the specified destructor callback and initial + * reference count. The initial references are owned by the caller. + * + * Destructor callbacks cannot make CUDA API calls and should avoid blocking + * behavior, as they are executed by a shared internal thread. Another thread + * may be signaled to perform such actions, if it does not block forward + * progress of tasks scheduled through CUDA. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information + * on user objects. + * + * \param object_out - Location to return the user object handle + * \param ptr - The pointer to pass to the destroy function + * \param destroy - Callback to free the user object when it is no + * longer in use \param initialRefcount - The initial refcount to create the + * object with, typically 1. The initial references are owned by the calling + * thread. \param flags - Currently it is required to pass + * ::CU_USER_OBJECT_NO_DESTRUCTOR_SYNC, which is the only defined flag. This + * indicates that the destroy callback cannot be waited on by any CUDA API. + * Users requiring synchronization of the callback should signal its completion + * manually. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectCreate(CUuserObject *object_out, void *ptr, + CUhostFn destroy, + unsigned int initialRefcount, + unsigned int flags); + +/** + * \brief Retain a reference to a user object + * + * Retains new references to a user object. The new references are owned by the + * caller. 
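+ *
+ * For illustration, a minimal reference-counting sketch (\c state and
+ * \c myDestroy are hypothetical user data and a destructor callback; error
+ * checking omitted):
+ *
+ * \code
+ * CUuserObject object;
+ * cuUserObjectCreate(&object, state, myDestroy, 1,
+ *                    CU_USER_OBJECT_NO_DESTRUCTOR_SYNC); // caller owns 1 ref
+ * cuUserObjectRetain(object, 1);                         // caller owns 2 refs
+ * cuGraphRetainUserObject(hGraph, object, 1,
+ *                         CU_GRAPH_USER_OBJECT_MOVE);    // move 1 ref to hGraph
+ * cuUserObjectRelease(object, 1);                        // caller keeps 1 ref
+ * \endcode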
+ * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information + * on user objects. + * + * \param object - The object to retain + * \param count - The number of references to retain, typically 1. Must be + * nonzero and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRetain(CUuserObject object, unsigned int count); + +/** + * \brief Release a reference to a user object + * + * Releases user object references owned by the caller. The object's destructor + * is invoked if the reference count reaches zero. + * + * It is undefined behavior to release references not owned by the caller, or to + * use a user object handle after all references are released. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information + * on user objects. + * + * \param object - The object to release + * \param count - The number of references to release, typically 1. Must be + * nonzero and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRelease(CUuserObject object, unsigned int count); + +/** + * \brief Retain a reference to a user object from a graph + * + * Creates or moves user object references that will be owned by a CUDA graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information + * on user objects. + * + * \param graph - The graph to associate the reference with + * \param object - The user object to retain a reference for + * \param count - The number of references to add to the graph, typically 1. + * Must be nonzero and not larger than INT_MAX. \param flags - The optional + * flag ::CU_GRAPH_USER_OBJECT_MOVE transfers references from the calling + * thread, rather than create new references. Pass 0 to create new references. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphRetainUserObject(CUgraph graph, CUuserObject object, + unsigned int count, + unsigned int flags); + +/** + * \brief Release a user object reference from a graph + * + * Releases user object references owned by a graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information + * on user objects. + * + * \param graph - The graph that will release the reference + * \param object - The user object to release a reference for + * \param count - The number of references to release, typically 1. Must be + * nonzero and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphReleaseUserObject(CUgraph graph, CUuserObject object, + unsigned int count); + +/** + * \brief Adds a node of arbitrary type to a graph + * + * Creates a new node in \p hGraph described by \p nodeParams with \p + * numDependencies dependencies specified via \p dependencies. \p + * numDependencies may be 0. 
\p dependencies may be null if \p numDependencies + * is 0. \p dependencies may not have any duplicate entries. + * + * \p nodeParams is a tagged union. The node type should be specified in the \p + * type field, and type-specific parameters in the corresponding union member. + * All unused bytes - that is, \p reserved0 and all bytes past the utilized + * union member - must be set to zero. It is recommended to use brace + * initialization or memset to ensure all bytes are initialized. + * + * Note that for some node types, \p nodeParams may contain "out parameters" + * which are modified during the call, such as \p nodeParams->alloc.dptr. + * + * A handle to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Specification of the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphNodeSetParams, + * ::cuGraphExecNodeSetParams + */ +CUresult CUDAAPI cuGraphAddNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUgraphNodeParams *nodeParams); + +/** + * \brief Adds a node of arbitrary type to a graph (12.3+) + * + * Creates a new node in \p hGraph described by \p nodeParams with \p + * numDependencies dependencies specified via \p dependencies. \p + * numDependencies may be 0. \p dependencies may be null if \p numDependencies + * is 0. \p dependencies may not have any duplicate entries. + * + * \p nodeParams is a tagged union. The node type should be specified in the \p + * type field, and type-specific parameters in the corresponding union member. + * All unused bytes - that is, \p reserved0 and all bytes past the utilized + * union member - must be set to zero. It is recommended to use brace + * initialization or memset to ensure all bytes are initialized. + * + * Note that for some node types, \p nodeParams may contain "out parameters" + * which are modified during the call, such as \p nodeParams->alloc.dptr. + * + * A handle to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param dependencyData - Optional edge data for the dependencies. If NULL, + * the data is assumed to be default (zeroed) for all dependencies. \param + * numDependencies - Number of dependencies \param nodeParams - + * Specification of the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphNodeSetParams, + * ::cuGraphExecNodeSetParams + */ +CUresult CUDAAPI cuGraphAddNode_v2(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + const CUgraphEdgeData *dependencyData, + size_t numDependencies, + CUgraphNodeParams *nodeParams); + +/** + * \brief Update's a graph node's parameters + * + * Sets the parameters of graph node \p hNode to \p nodeParams. The node type + * specified by \p nodeParams->type must match the type of \p hNode. 
\p
+ * nodeParams must be fully initialized and all unused bytes (reserved, padding)
+ * zeroed.
+ *
+ * Modifying parameters is not supported for node types
+ * CU_GRAPH_NODE_TYPE_MEM_ALLOC and CU_GRAPH_NODE_TYPE_MEM_FREE.
+ *
+ * \param hNode - Node to set the parameters for
+ * \param nodeParams - Parameters to copy
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_NOT_SUPPORTED
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphAddNode,
+ * ::cuGraphExecNodeSetParams
+ */
+CUresult CUDAAPI cuGraphNodeSetParams(CUgraphNode hNode,
+                                      CUgraphNodeParams *nodeParams);
+
+/**
+ * \brief Updates a graph node's parameters in an instantiated graph
+ *
+ * Sets the parameters of a node in an executable graph \p hGraphExec. The node
+ * is identified by the corresponding node \p hNode in the non-executable graph
+ * from which the executable graph was instantiated. \p hNode must not have been
+ * removed from the original graph.
+ *
+ * The modifications only affect future launches of \p hGraphExec. Already
+ * enqueued or running launches of \p hGraphExec are not affected by this call.
+ * \p hNode is also not modified by this call.
+ *
+ * Allowed changes to parameters on executable graphs are as follows:
+ *
+ * - kernel: See ::cuGraphExecKernelNodeSetParams
+ * - memcpy: Addresses for 1-dimensional copies if allocated in same context;
+ *   see ::cuGraphExecMemcpyNodeSetParams
+ * - memset: Addresses for 1-dimensional memsets if allocated in same context;
+ *   see ::cuGraphExecMemsetNodeSetParams
+ * - host: Unrestricted
+ * - child graph: Topology must match and restrictions apply recursively; see
+ *   ::cuGraphExecUpdate
+ * - event wait: Unrestricted
+ * - event record: Unrestricted
+ * - external semaphore signal: Number of semaphore operations cannot change
+ * - external semaphore wait: Number of semaphore operations cannot change
+ * - memory allocation: API unsupported
+ * - memory free: API unsupported
+ * - batch memops: Addresses, values, and operation type for wait operations;
+ *   see ::cuGraphExecBatchMemOpNodeSetParams
+ * + * \param hGraphExec - The executable graph in which to update the specified + * node \param hNode - Corresponding node from the graph from which + * graphExec was instantiated \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode, + * ::cuGraphNodeSetParams + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecNodeSetParams(CUgraphExec hGraphExec, + CUgraphNode hNode, + CUgraphNodeParams *nodeParams); + +/** + * \brief Create a conditional handle + * + * Creates a conditional handle associated with \p hGraph. + * + * The conditional handle must be associated with a conditional node in this + * graph or one of its children. + * + * Handles not associated with a conditional node may cause graph instantiation + * to fail. + * + * Handles can only be set from the context with which they are associated. + * + * \param pHandle_out - Pointer used to return the handle to the caller. + * \param hGraph - Graph which will contain the conditional node + * using this handle. \param ctx - Context for the handle and + * associated conditional node. \param defaultLaunchValue - Optional initial + * value for the conditional variable. \param flags - Currently + * must be CU_GRAPH_COND_ASSIGN_DEFAULT or 0. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddNode + */ +CUresult CUDAAPI cuGraphConditionalHandleCreate( + CUgraphConditionalHandle *pHandle_out, CUgraph hGraph, CUcontext ctx, + unsigned int defaultLaunchValue, unsigned int flags); + +/** @} */ /* END CUDA_GRAPH */ + +/** + * \defgroup CUDA_OCCUPANCY Occupancy + * + * ___MANBRIEF___ occupancy calculation functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the occupancy calculation functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns occupancy of a function + * + * Returns in \p *numBlocks the number of the maximum active blocks per + * streaming multiprocessor. + * + * \param numBlocks - Returned occupancy + * \param func - Kernel for which occupancy is calculated + * \param blockSize - Block size the kernel is intended to be launched + * with \param dynamicSMemSize - Per-block dynamic shared memory usage intended, + * in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaOccupancyMaxActiveBlocksPerMultiprocessor + */ +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize); + +/** + * \brief Returns occupancy of a function + * + * Returns in \p *numBlocks the number of the maximum active blocks per + * streaming multiprocessor. + * + * The \p Flags parameter controls how special cases are handled. The + * valid flags are: + * + * - ::CU_OCCUPANCY_DEFAULT, which maintains the default behavior as + * ::cuOccupancyMaxActiveBlocksPerMultiprocessor; + * + * - ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, which suppresses the + * default behavior on platform where global caching affects + * occupancy. 
On such platforms, if caching is enabled, but + * per-block SM resource usage would result in zero occupancy, the + * occupancy calculator will calculate the occupancy as if caching + * is disabled. Setting ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE makes + * the occupancy calculator to return 0 in such cases. More information + * can be found about this feature in the "Unified L1/Texture Cache" + * section of the Maxwell tuning guide. + * + * \param numBlocks - Returned occupancy + * \param func - Kernel for which occupancy is calculated + * \param blockSize - Block size the kernel is intended to be launched + * with \param dynamicSMemSize - Per-block dynamic shared memory usage intended, + * in bytes \param flags - Requested behavior for the occupancy + * calculator + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + */ +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, + unsigned int flags); + +/** + * \brief Suggest a launch configuration with reasonable occupancy + * + * Returns in \p *blockSize a reasonable block size that can achieve + * the maximum occupancy (or, the maximum number of active warps with + * the fewest blocks per multiprocessor), and in \p *minGridSize the + * minimum grid size to achieve the maximum occupancy. + * + * If \p blockSizeLimit is 0, the configurator will use the maximum + * block size permitted by the device / function instead. + * + * If per-block dynamic shared memory allocation is not needed, the + * user should leave both \p blockSizeToDynamicSMemSize and \p + * dynamicSMemSize as 0. + * + * If per-block dynamic shared memory allocation is needed, then if + * the dynamic shared memory size is constant regardless of block + * size, the size should be passed through \p dynamicSMemSize, and \p + * blockSizeToDynamicSMemSize should be NULL. + * + * Otherwise, if the per-block dynamic shared memory size varies with + * different block sizes, the user needs to provide a unary function + * through \p blockSizeToDynamicSMemSize that computes the dynamic + * shared memory needed by \p func for any given block size. \p + * dynamicSMemSize is ignored. 
An example signature is:
+ *
+ * \code
+ * // Take block size, returns dynamic shared memory needed
+ * size_t blockToSmem(int blockSize);
+ * \endcode
+ *
+ * \param minGridSize - Returned minimum grid size needed to achieve the maximum
+ * occupancy \param blockSize - Returned maximum block size that can achieve
+ * the maximum occupancy \param func - Kernel for which launch
+ * configuration is calculated \param blockSizeToDynamicSMemSize - A function
+ * that calculates how much per-block dynamic shared memory \p func uses based
+ * on the block size \param dynamicSMemSize - Dynamic shared memory usage
+ * intended, in bytes \param blockSizeLimit - The maximum block size \p func is
+ * designed to handle
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_UNKNOWN
+ * \notefnerr
+ *
+ * \sa
+ * ::cudaOccupancyMaxPotentialBlockSize
+ */
+CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize(
+    int *minGridSize, int *blockSize, CUfunction func,
+    CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize,
+    int blockSizeLimit);
+
+/**
+ * \brief Suggest a launch configuration with reasonable occupancy
+ *
+ * An extended version of ::cuOccupancyMaxPotentialBlockSize. In
+ * addition to arguments passed to ::cuOccupancyMaxPotentialBlockSize,
+ * ::cuOccupancyMaxPotentialBlockSizeWithFlags also takes a \p Flags
+ * parameter.
+ *
+ * The \p Flags parameter controls how special cases are handled. The
+ * valid flags are:
+ *
+ * - ::CU_OCCUPANCY_DEFAULT, which maintains the default behavior as
+ *   ::cuOccupancyMaxPotentialBlockSize;
+ *
+ * - ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, which suppresses the
+ *   default behavior on platforms where global caching affects
+ *   occupancy. On such platforms, the launch configuration that
+ *   produces maximal occupancy might not support global
+ *   caching. Setting ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE
+ *   guarantees that the produced launch configuration is global
+ *   caching compatible at a potential cost of occupancy. More information
+ *   can be found about this feature in the "Unified L1/Texture Cache"
+ *   section of the Maxwell tuning guide.
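+ *
+ * For illustration, a minimal sketch of picking a launch configuration with
+ * the default flags (no dynamic shared memory and no block-size limit;
+ * \c kernel and \c n are placeholders; error checking omitted):
+ *
+ * \code
+ * int minGridSize = 0, blockSize = 0;
+ * cuOccupancyMaxPotentialBlockSizeWithFlags(&minGridSize, &blockSize, kernel,
+ *                                           NULL, 0, 0, CU_OCCUPANCY_DEFAULT);
+ * // Round the grid up so all n elements of the problem are covered.
+ * int gridSize = (n + blockSize - 1) / blockSize;
+ * \endcode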
+ * + * \param minGridSize - Returned minimum grid size needed to achieve the maximum + * occupancy \param blockSize - Returned maximum block size that can achieve + * the maximum occupancy \param func - Kernel for which launch + * configuration is calculated \param blockSizeToDynamicSMemSize - A function + * that calculates how much per-block dynamic shared memory \p func uses based + * on the block size \param dynamicSMemSize - Dynamic shared memory usage + * intended, in bytes \param blockSizeLimit - The maximum block size \p func is + * designed to handle \param flags - Options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaOccupancyMaxPotentialBlockSizeWithFlags + */ +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit, unsigned int flags); + +/** + * \brief Returns dynamic shared memory available per block when launching \p + * numBlocks blocks on SM + * + * Returns in \p *dynamicSmemSize the maximum size of dynamic shared memory to + * allow \p numBlocks blocks per SM. + * + * \param dynamicSmemSize - Returned maximum dynamic shared memory + * \param func - Kernel function for which occupancy is calculated + * \param numBlocks - Number of blocks to fit on SM + * \param blockSize - Size of the blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + */ +CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock( + size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize); + +/** + * \brief Given the kernel function (\p func) and launch configuration + * (\p config), return the maximum cluster size in \p *clusterSize. + * + * The cluster dimensions in \p config are ignored. If func has a required + * cluster size set (see ::cudaFuncGetAttributes / ::cuFuncGetAttribute),\p + * *clusterSize will reflect the required cluster size. + * + * By default this function will always return a value that's portable on + * future hardware. A higher value may be returned if the kernel function + * allows non-portable cluster sizes. + * + * This function will respect the compile time launch bounds. + * + * \param clusterSize - Returned maximum cluster size that can be launched + * for the given kernel function and launch configuration + * \param func - Kernel function for which maximum cluster + * size is calculated + * \param config - Launch configuration for the given kernel function + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaFuncGetAttributes, + * ::cuFuncGetAttribute + */ +CUresult CUDAAPI cuOccupancyMaxPotentialClusterSize( + int *clusterSize, CUfunction func, const CUlaunchConfig *config); + +/** + * \brief Given the kernel function (\p func) and launch configuration + * (\p config), return the maximum number of clusters that could co-exist + * on the target device in \p *numClusters. 
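+ *
+ * For illustration, a minimal sketch (assumes \c config is a ::CUlaunchConfig
+ * describing the intended launch, including a cluster dimension supplied via
+ * ::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; error checking omitted):
+ *
+ * \code
+ * int numClusters = 0;
+ * cuOccupancyMaxActiveClusters(&numClusters, func, &config);
+ * \endcode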
+ * + * If the function has required cluster size already set (see + * ::cudaFuncGetAttributes / ::cuFuncGetAttribute), the cluster size + * from config must either be unspecified or match the required size. + * Without required sizes, the cluster size must be specified in config, + * else the function will return an error. + * + * Note that various attributes of the kernel function may affect occupancy + * calculation. Runtime environment may affect how the hardware schedules + * the clusters, so the calculated occupancy is not guaranteed to be achievable. + * + * \param numClusters - Returned maximum number of clusters that + * could co-exist on the target device + * \param func - Kernel function for which maximum number + * of clusters are calculated + * \param config - Launch configuration for the given kernel function + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CLUSTER_SIZE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaFuncGetAttributes, + * ::cuFuncGetAttribute + */ +CUresult CUDAAPI cuOccupancyMaxActiveClusters(int *numClusters, CUfunction func, + const CUlaunchConfig *config); +/** @} */ /* END CUDA_OCCUPANCY */ + +/** + * \defgroup CUDA_TEXREF_DEPRECATED Texture Reference Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated texture reference management functions of the + * low-level CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated texture reference management + * functions of the low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Binds an array as a texture reference + * + * \deprecated + * + * Binds the CUDA array \p hArray to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. \p Flags must be set to + * ::CU_TRSA_OVERRIDE_FORMAT. Any CUDA array previously bound to \p hTexRef is + * unbound. + * + * \param hTexRef - Texture reference to bind + * \param hArray - Array to bind + * \param Flags - Options (must be ::CU_TRSA_OVERRIDE_FORMAT) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, + CUarray hArray, + unsigned int Flags); + +/** + * \brief Binds a mipmapped array to a texture reference + * + * \deprecated + * + * Binds the CUDA mipmapped array \p hMipmappedArray to the texture reference \p + * hTexRef. Any previous address or CUDA array state associated with the texture + * reference is superseded by this function. \p Flags must be set to + * ::CU_TRSA_OVERRIDE_FORMAT. Any CUDA array previously bound to \p hTexRef is + * unbound. 
+ * + * \param hTexRef - Texture reference to bind + * \param hMipmappedArray - Mipmapped array to bind + * \param Flags - Options (must be ::CU_TRSA_OVERRIDE_FORMAT) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray( + CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags); + +/** + * \brief Binds an address as a texture reference + * + * \deprecated + * + * Binds a linear address range to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. Any memory previously bound to \p hTexRef + * is unbound. + * + * Since the hardware enforces an alignment requirement on texture base + * addresses, ::cuTexRefSetAddress() passes back a byte offset in + * \p *ByteOffset that must be applied to texture fetches in order to read from + * the desired memory. This offset must be divided by the texel size and + * passed to kernels that read from the texture so they can be applied to the + * ::tex1Dfetch() function. + * + * If the device memory pointer was returned from ::cuMemAlloc(), the offset + * is guaranteed to be 0 and NULL may be passed as the \p ByteOffset parameter. + * + * The total number of elements (or texels) in the linear address range + * cannot exceed ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH. + * The number of elements is computed as (\p bytes / bytesPerElement), + * where bytesPerElement is determined from the data format and number of + * components set using ::cuTexRefSetFormat(). + * + * \param ByteOffset - Returned byte offset + * \param hTexRef - Texture reference to bind + * \param dptr - Device pointer to bind + * \param bytes - Size of memory to bind in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, + CUtexref hTexRef, + CUdeviceptr dptr, + size_t bytes); + +/** + * \brief Binds an address as a 2D texture reference + * + * \deprecated + * + * Binds a linear address range to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. Any memory previously bound to \p hTexRef + * is unbound. + * + * Using a ::tex2D() function inside a kernel requires a call to either + * ::cuTexRefSetArray() to bind the corresponding texture reference to an + * array, or ::cuTexRefSetAddress2D() to bind the texture reference to linear + * memory. + * + * Function calls to ::cuTexRefSetFormat() cannot follow calls to + * ::cuTexRefSetAddress2D() for the same texture reference. 
+ * + * It is required that \p dptr be aligned to the appropriate hardware-specific + * texture alignment. You can query this value using the device attribute + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. If an unaligned \p dptr is + * supplied, ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \p Pitch has to be aligned to the hardware-specific texture pitch alignment. + * This value can be queried using the device attribute + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT. If an unaligned \p Pitch is + * supplied, ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Width and Height, which are specified in elements (or texels), cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH and + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT respectively. + * \p Pitch, which is specified in bytes, cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH. + * + * \param hTexRef - Texture reference to bind + * \param desc - Descriptor of CUDA array + * \param dptr - Device pointer to bind + * \param Pitch - Line pitch in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch); + +/** + * \brief Sets the format for a texture reference + * + * \deprecated + * + * Specifies the format of the data to be read by the texture reference + * \p hTexRef. \p fmt and \p NumPackedComponents are exactly analogous to the + * ::Format and ::NumChannels members of the ::CUDA_ARRAY_DESCRIPTOR structure: + * They specify the format of each component and the number of components per + * array element. + * + * \param hTexRef - Texture reference + * \param fmt - Format to set + * \param NumPackedComponents - Number of components per array element + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaCreateChannelDesc + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, + CUarray_format fmt, + int NumPackedComponents); + +/** + * \brief Sets the addressing mode for a texture reference + * + * \deprecated + * + * Specifies the addressing mode \p am for the given dimension \p dim of the + * texture reference \p hTexRef. If \p dim is zero, the addressing mode is + * applied to the first parameter of the functions used to fetch from the + * texture; if \p dim is 1, the second, and so on. 
::CUaddress_mode is defined + * as: + * \code + typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, + CU_TR_ADDRESS_MODE_CLAMP = 1, + CU_TR_ADDRESS_MODE_MIRROR = 2, + CU_TR_ADDRESS_MODE_BORDER = 3 + } CUaddress_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. + * Also, if the flag, ::CU_TRSF_NORMALIZED_COORDINATES, is not set, the only + * supported address mode is ::CU_TR_ADDRESS_MODE_CLAMP. + * + * \param hTexRef - Texture reference + * \param dim - Dimension + * \param am - Addressing mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, + int dim, + CUaddress_mode am); + +/** + * \brief Sets the filtering mode for a texture reference + * + * \deprecated + * + * Specifies the filtering mode \p fm to be used when reading memory through + * the texture reference \p hTexRef. ::CUfilter_mode_enum is defined as: + * + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. + * + * \param hTexRef - Texture reference + * \param fm - Filtering mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, + CUfilter_mode fm); + +/** + * \brief Sets the mipmap filtering mode for a texture reference + * + * \deprecated + * + * Specifies the mipmap filtering mode \p fm to be used when reading memory + through + * the texture reference \p hTexRef. ::CUfilter_mode_enum is defined as: + * + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped + array. 
+ * + * \param hTexRef - Texture reference + * \param fm - Filtering mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm); + +/** + * \brief Sets the mipmap level bias for a texture reference + * + * \deprecated + * + * Specifies the mipmap level bias \p bias to be added to the specified mipmap + * level when reading memory through the texture reference \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped + * array. + * + * \param hTexRef - Texture reference + * \param bias - Mipmap level bias + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, + float bias); + +/** + * \brief Sets the mipmap min/max mipmap level clamps for a texture reference + * + * \deprecated + * + * Specifies the min/max mipmap level clamps, \p minMipmapLevelClamp and \p + * maxMipmapLevelClamp respectively, to be used when reading memory through the + * texture reference \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped + * array. + * + * \param hTexRef - Texture reference + * \param minMipmapLevelClamp - Mipmap min level clamp + * \param maxMipmapLevelClamp - Mipmap max level clamp + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp( + CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp); + +/** + * \brief Sets the maximum anisotropy for a texture reference + * + * \deprecated + * + * Specifies the maximum anisotropy \p maxAniso to be used when reading memory + * through the texture reference \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. 
+ * + * \param hTexRef - Texture reference + * \param maxAniso - Maximum anisotropy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso); + +/** + * \brief Sets the border color for a texture reference + * + * \deprecated + * + * Specifies the value of the RGBA color via the \p pBorderColor to the texture + * reference \p hTexRef. The color value supports only float type and holds + * color components in the following sequence: pBorderColor[0] holds 'R' + * component pBorderColor[1] holds 'G' component pBorderColor[2] holds 'B' + * component pBorderColor[3] holds 'A' component + * + * Note that the color values can be set only when the Address mode is set to + * CU_TR_ADDRESS_MODE_BORDER using ::cuTexRefSetAddressMode. + * Applications using integer border color values have to "reinterpret_cast" + * their values to float. + * + * \param hTexRef - Texture reference + * \param pBorderColor - RGBA color + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddressMode, + * ::cuTexRefGetAddressMode, ::cuTexRefGetBorderColor + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, + float *pBorderColor); + +/** + * \brief Sets the flags for a texture reference + * + * \deprecated + * + * Specifies optional flags via \p Flags to specify the behavior of data + * returned through the texture reference \p hTexRef. The valid flags are: + * + * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of + * having the texture promote integer data to floating point data in the + * range [0, 1]. Note that texture with 32-bit integer format + * would not be promoted, regardless of whether or not this + * flag is specified; + * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the + * default behavior of having the texture coordinates range + * from [0, Dim) where Dim is the width or height of the CUDA + * array. Instead, the texture coordinates [0, 1.0) reference + * the entire breadth of the array dimension; + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. 
+ * + * \param hTexRef - Texture reference + * \param Flags - Optional flags to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, + unsigned int Flags); + +/** + * \brief Gets the address associated with a texture reference + * + * \deprecated + * + * Returns in \p *pdptr the base address bound to the texture reference + * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference + * is not bound to any device memory range. + * + * \param pdptr - Returned device address + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, + CUtexref hTexRef); + +/** + * \brief Gets the array bound to a texture reference + * + * \deprecated + * + * Returns in \p *phArray the CUDA array bound to the texture reference + * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference + * is not bound to any CUDA array. + * + * \param phArray - Returned array + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, + CUtexref hTexRef); + +/** + * \brief Gets the mipmapped array bound to a texture reference + * + * \deprecated + * + * Returns in \p *phMipmappedArray the CUDA mipmapped array bound to the texture + * reference \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture + * reference is not bound to any CUDA mipmapped array. 
+ * + * \param phMipmappedArray - Returned mipmapped array + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, CUtexref hTexRef); + +/** + * \brief Gets the addressing mode used by a texture reference + * + * \deprecated + * + * Returns in \p *pam the addressing mode corresponding to the + * dimension \p dim of the texture reference \p hTexRef. Currently, the only + * valid value for \p dim are 0 and 1. + * + * \param pam - Returned addressing mode + * \param hTexRef - Texture reference + * \param dim - Dimension + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, + CUtexref hTexRef, + int dim); + +/** + * \brief Gets the filter-mode used by a texture reference + * + * \deprecated + * + * Returns in \p *pfm the filtering mode of the texture reference + * \p hTexRef. + * + * \param pfm - Returned filtering mode + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, + CUtexref hTexRef); + +/** + * \brief Gets the format used by a texture reference + * + * \deprecated + * + * Returns in \p *pFormat and \p *pNumChannels the format and number + * of components of the CUDA array bound to the texture reference \p hTexRef. + * If \p pFormat or \p pNumChannels is NULL, it will be ignored. 
+ * + * \param pFormat - Returned format + * \param pNumChannels - Returned number of components + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, + int *pNumChannels, + CUtexref hTexRef); + +/** + * \brief Gets the mipmap filtering mode for a texture reference + * + * \deprecated + * + * Returns the mipmap filtering mode in \p pfm that's used when reading memory + * through the texture reference \p hTexRef. + * + * \param pfm - Returned mipmap filtering mode + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); + +/** + * \brief Gets the mipmap level bias for a texture reference + * + * \deprecated + * + * Returns the mipmap level bias in \p pBias that's added to the specified + * mipmap level when reading memory through the texture reference \p hTexRef. + * + * \param pbias - Returned mipmap level bias + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, + CUtexref hTexRef); + +/** + * \brief Gets the min/max mipmap level clamps for a texture reference + * + * \deprecated + * + * Returns the min/max mipmap level clamps in \p pminMipmapLevelClamp and \p + * pmaxMipmapLevelClamp that's used when reading memory through the texture + * reference \p hTexRef. 
+ * + * \param pminMipmapLevelClamp - Returned mipmap min level clamp + * \param pmaxMipmapLevelClamp - Returned mipmap max level clamp + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelClamp( + float *pminMipmapLevelClamp, float *pmaxMipmapLevelClamp, CUtexref hTexRef); + +/** + * \brief Gets the maximum anisotropy for a texture reference + * + * \deprecated + * + * Returns the maximum anisotropy in \p pmaxAniso that's used when reading + * memory through the texture reference \p hTexRef. + * + * \param pmaxAniso - Returned maximum anisotropy + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, + CUtexref hTexRef); + +/** + * \brief Gets the border color used by a texture reference + * + * \deprecated + * + * Returns in \p pBorderColor, values of the RGBA color used by + * the texture reference \p hTexRef. + * The color value is of type float and holds color components in + * the following sequence: + * pBorderColor[0] holds 'R' component + * pBorderColor[1] holds 'G' component + * pBorderColor[2] holds 'B' component + * pBorderColor[3] holds 'A' component + * + * \param hTexRef - Texture reference + * \param pBorderColor - Returned Type and Value of RGBA color + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddressMode, + * ::cuTexRefSetAddressMode, ::cuTexRefSetBorderColor + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, + CUtexref hTexRef); + +/** + * \brief Gets the flags used by a texture reference + * + * \deprecated + * + * Returns in \p *pFlags the flags of the texture reference \p hTexRef. + * + * \param pFlags - Returned flags + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, + CUtexref hTexRef); + +/** + * \brief Creates a texture reference + * + * \deprecated + * + * Creates a texture reference and returns its handle in \p *pTexRef. 
Once + * created, the application must call ::cuTexRefSetArray() or + * ::cuTexRefSetAddress() to associate the reference with allocated memory. + * Other texture reference functions are used to specify the format and + * interpretation (addressing, filtering, etc.) to be used when the memory is + * read through this texture reference. + * + * \param pTexRef - Returned texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefDestroy + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef); + +/** + * \brief Destroys a texture reference + * + * \deprecated + * + * Destroys the texture reference specified by \p hTexRef. + * + * \param hTexRef - Texture reference to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefCreate + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); + +/** @} */ /* END CUDA_TEXREF_DEPRECATED */ + +/** + * \defgroup CUDA_SURFREF_DEPRECATED Surface Reference Management [DEPRECATED] + * + * ___MANBRIEF___ surface reference management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the surface reference management functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Sets the CUDA array for a surface reference. + * + * \deprecated + * + * Sets the CUDA array \p hArray to be read and written by the surface reference + * \p hSurfRef. Any previous CUDA array state associated with the surface + * reference is superseded by this function. \p Flags must be set to 0. + * The ::CUDA_ARRAY3D_SURFACE_LDST flag must have been set for the CUDA array. + * Any CUDA array previously bound to \p hSurfRef is unbound. + + * \param hSurfRef - Surface reference handle + * \param hArray - CUDA array handle + * \param Flags - set to 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuModuleGetSurfRef, + * ::cuSurfRefGetArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, + CUarray hArray, + unsigned int Flags); + +/** + * \brief Passes back the CUDA array bound to a surface reference. + * + * \deprecated + * + * Returns in \p *phArray the CUDA array bound to the surface reference + * \p hSurfRef, or returns ::CUDA_ERROR_INVALID_VALUE if the surface reference + * is not bound to any CUDA array. 
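+ *
+ * \par
+ * As an illustrative sketch only (assuming \p hSurfRef came from
+ * ::cuModuleGetSurfRef() and \p hArray was created with the
+ * ::CUDA_ARRAY3D_SURFACE_LDST flag; error handling trimmed), a round trip
+ * through this deprecated surface reference API might look like this:
+ * \code
+    CUarray boundArray = NULL;
+    cuSurfRefSetArray(hSurfRef, hArray, 0);        // Flags must be 0
+    cuSurfRefGetArray(&boundArray, hSurfRef);      // boundArray == hArray on success
+ * \endcode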
+ + * \param phArray - Surface reference handle + * \param hSurfRef - Surface reference handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuModuleGetSurfRef, ::cuSurfRefSetArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, + CUsurfref hSurfRef); + +/** @} */ /* END CUDA_SURFREF_DEPRECATED */ + +/** + * \defgroup CUDA_TEXOBJECT Texture Object Management + * + * ___MANBRIEF___ texture object management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the texture object management functions of the + * low-level CUDA driver application programming interface. The texture + * object API is only supported on devices of compute capability 3.0 or higher. + * + * @{ + */ + +/** + * \brief Creates a texture object + * + * Creates a texture object and returns it in \p pTexObject. \p pResDesc + describes + * the data to texture from. \p pTexDesc describes how the data should be + sampled. + * \p pResViewDesc is an optional argument that specifies an alternate format + for + * the data described by \p pResDesc, and also describes the subresource region + * to restrict access to when texturing. \p pResViewDesc can only be specified + if + * the type of resource is a CUDA array or a CUDA mipmapped array. + * + * Texture objects are only supported on devices of compute capability 3.0 or + higher. + * Additionally, a texture object is an opaque value, and, as such, should only + be + * accessed through CUDA API calls. + * + * The ::CUDA_RESOURCE_DESC structure is defined as: + * \code + typedef struct CUDA_RESOURCE_DESC_st + { + CUresourcetype resType; + + union { + struct { + CUarray hArray; + } array; + struct { + CUmipmappedArray hMipmappedArray; + } mipmap; + struct { + CUdeviceptr devPtr; + CUarray_format format; + unsigned int numChannels; + size_t sizeInBytes; + } linear; + struct { + CUdeviceptr devPtr; + CUarray_format format; + unsigned int numChannels; + size_t width; + size_t height; + size_t pitchInBytes; + } pitch2D; + } res; + + unsigned int flags; + } CUDA_RESOURCE_DESC; + + * \endcode + * where: + * - ::CUDA_RESOURCE_DESC::resType specifies the type of resource to texture + from. + * CUresourceType is defined as: + * \code + typedef enum CUresourcetype_enum { + CU_RESOURCE_TYPE_ARRAY = 0x00, + CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, + CU_RESOURCE_TYPE_LINEAR = 0x02, + CU_RESOURCE_TYPE_PITCH2D = 0x03 + } CUresourcetype; + * \endcode + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_ARRAY, + ::CUDA_RESOURCE_DESC::res::array::hArray + * must be set to a valid CUDA array handle. + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to + ::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY, + ::CUDA_RESOURCE_DESC::res::mipmap::hMipmappedArray + * must be set to a valid CUDA mipmapped array handle. + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_LINEAR, + ::CUDA_RESOURCE_DESC::res::linear::devPtr + * must be set to a valid device pointer, that is aligned to + ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. + * ::CUDA_RESOURCE_DESC::res::linear::format and + ::CUDA_RESOURCE_DESC::res::linear::numChannels + * describe the format of each component and the number of components per array + element. ::CUDA_RESOURCE_DESC::res::linear::sizeInBytes + * specifies the size of the array in bytes. 
The total number of elements in the + linear address range cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH. The number of elements + is computed as (sizeInBytes / (sizeof(format) * numChannels)). + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_PITCH2D, + ::CUDA_RESOURCE_DESC::res::pitch2D::devPtr + * must be set to a valid device pointer, that is aligned to + ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. + * ::CUDA_RESOURCE_DESC::res::pitch2D::format and + ::CUDA_RESOURCE_DESC::res::pitch2D::numChannels + * describe the format of each component and the number of components per array + element. ::CUDA_RESOURCE_DESC::res::pitch2D::width + * and ::CUDA_RESOURCE_DESC::res::pitch2D::height specify the width and height + of the array in elements, and cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH and + ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT respectively. + * ::CUDA_RESOURCE_DESC::res::pitch2D::pitchInBytes specifies the pitch between + two rows in bytes and has to be aligned to + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT. Pitch cannot exceed + ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH. + * + * - ::flags must be set to zero. + * + * + * The ::CUDA_TEXTURE_DESC struct is defined as + * \code + typedef struct CUDA_TEXTURE_DESC_st { + CUaddress_mode addressMode[3]; + CUfilter_mode filterMode; + unsigned int flags; + unsigned int maxAnisotropy; + CUfilter_mode mipmapFilterMode; + float mipmapLevelBias; + float minMipmapLevelClamp; + float maxMipmapLevelClamp; + } CUDA_TEXTURE_DESC; + * \endcode + * where + * - ::CUDA_TEXTURE_DESC::addressMode specifies the addressing mode for each + dimension of the texture data. ::CUaddress_mode is defined as: + * \code + typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, + CU_TR_ADDRESS_MODE_CLAMP = 1, + CU_TR_ADDRESS_MODE_MIRROR = 2, + CU_TR_ADDRESS_MODE_BORDER = 3 + } CUaddress_mode; + * \endcode + * This is ignored if ::CUDA_RESOURCE_DESC::resType is + ::CU_RESOURCE_TYPE_LINEAR. Also, if the flag, ::CU_TRSF_NORMALIZED_COORDINATES + * is not set, the only supported address mode is ::CU_TR_ADDRESS_MODE_CLAMP. + * + * - ::CUDA_TEXTURE_DESC::filterMode specifies the filtering mode to be used + when fetching from the texture. CUfilter_mode is defined as: + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * This is ignored if ::CUDA_RESOURCE_DESC::resType is + ::CU_RESOURCE_TYPE_LINEAR. + * + * - ::CUDA_TEXTURE_DESC::flags can be any combination of the following: + * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of + * having the texture promote integer data to floating point data in the + * range [0, 1]. Note that texture with 32-bit integer format would not be + * promoted, regardless of whether or not this flag is specified. + * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the default behavior + * of having the texture coordinates range from [0, Dim) where Dim is the + * width or height of the CUDA array. Instead, the texture coordinates + * [0, 1.0) reference the entire breadth of the array dimension; Note that + * for CUDA mipmapped arrays, this flag has to be set. + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. 
Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. + * - ::CU_TRSF_SEAMLESS_CUBEMAP, which enables seamless cube map filtering. + * This flag can only be specified if the underlying resource is a CUDA array + * or a CUDA mipmapped array that was created with the flag + ::CUDA_ARRAY3D_CUBEMAP. + * When seamless cube map filtering is enabled, texture address modes + specified + * by ::CUDA_TEXTURE_DESC::addressMode are ignored. Instead, if the + ::CUDA_TEXTURE_DESC::filterMode + * is set to ::CU_TR_FILTER_MODE_POINT the address mode + ::CU_TR_ADDRESS_MODE_CLAMP + * will be applied for all dimensions. If the ::CUDA_TEXTURE_DESC::filterMode + is + * set to ::CU_TR_FILTER_MODE_LINEAR seamless cube map filtering will be + performed + * when sampling along the cube face borders. + * + * - ::CUDA_TEXTURE_DESC::maxAnisotropy specifies the maximum anisotropy ratio + to be used when doing anisotropic filtering. This value will be + * clamped to the range [1,16]. + * + * - ::CUDA_TEXTURE_DESC::mipmapFilterMode specifies the filter mode when the + calculated mipmap level lies between two defined mipmap levels. + * + * - ::CUDA_TEXTURE_DESC::mipmapLevelBias specifies the offset to be applied to + the calculated mipmap level. + * + * - ::CUDA_TEXTURE_DESC::minMipmapLevelClamp specifies the lower end of the + mipmap level range to clamp access to. + * + * - ::CUDA_TEXTURE_DESC::maxMipmapLevelClamp specifies the upper end of the + mipmap level range to clamp access to. + * + * + * The ::CUDA_RESOURCE_VIEW_DESC struct is defined as + * \code + typedef struct CUDA_RESOURCE_VIEW_DESC_st + { + CUresourceViewFormat format; + size_t width; + size_t height; + size_t depth; + unsigned int firstMipmapLevel; + unsigned int lastMipmapLevel; + unsigned int firstLayer; + unsigned int lastLayer; + } CUDA_RESOURCE_VIEW_DESC; + * \endcode + * where: + * - ::CUDA_RESOURCE_VIEW_DESC::format specifies how the data contained in the + CUDA array or CUDA mipmapped array should + * be interpreted. Note that this can incur a change in size of the texture + data. If the resource view format is a block + * compressed format, then the underlying CUDA array or CUDA mipmapped array + has to have a base of format ::CU_AD_FORMAT_UNSIGNED_INT32. + * with 2 or 4 channels, depending on the block compressed format. For ex., + BC1 and BC4 require the underlying CUDA array to have + * a format of ::CU_AD_FORMAT_UNSIGNED_INT32 with 2 channels. The other BC + formats require the underlying resource to have the same base + * format but with 4 channels. + * + * - ::CUDA_RESOURCE_VIEW_DESC::width specifies the new width of the texture + data. If the resource view format is a block + * compressed format, this value has to be 4 times the original width of the + resource. For non block compressed formats, + * this value has to be equal to that of the original resource. + * + * - ::CUDA_RESOURCE_VIEW_DESC::height specifies the new height of the texture + data. If the resource view format is a block + * compressed format, this value has to be 4 times the original height of the + resource. For non block compressed formats, + * this value has to be equal to that of the original resource. + * + * - ::CUDA_RESOURCE_VIEW_DESC::depth specifies the new depth of the texture + data. This value has to be equal to that of the + * original resource. 
+ * + * - ::CUDA_RESOURCE_VIEW_DESC::firstMipmapLevel specifies the most detailed + mipmap level. This will be the new mipmap level zero. + * For non-mipmapped resources, this value has to be + zero.::CUDA_TEXTURE_DESC::minMipmapLevelClamp and + ::CUDA_TEXTURE_DESC::maxMipmapLevelClamp + * will be relative to this value. For ex., if the firstMipmapLevel is set to + 2, and a minMipmapLevelClamp of 1.2 is specified, + * then the actual minimum mipmap level clamp will be 3.2. + * + * - ::CUDA_RESOURCE_VIEW_DESC::lastMipmapLevel specifies the least detailed + mipmap level. For non-mipmapped resources, this value + * has to be zero. + * + * - ::CUDA_RESOURCE_VIEW_DESC::firstLayer specifies the first layer index for + layered textures. This will be the new layer zero. + * For non-layered resources, this value has to be zero. + * + * - ::CUDA_RESOURCE_VIEW_DESC::lastLayer specifies the last layer index for + layered textures. For non-layered resources, + * this value has to be zero. + * + * + * \param pTexObject - Texture object to create + * \param pResDesc - Resource descriptor + * \param pTexDesc - Texture descriptor + * \param pResViewDesc - Resource view descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectDestroy, + * ::cudaCreateTextureObject + */ +CUresult CUDAAPI cuTexObjectCreate(CUtexObject *pTexObject, + const CUDA_RESOURCE_DESC *pResDesc, + const CUDA_TEXTURE_DESC *pTexDesc, + const CUDA_RESOURCE_VIEW_DESC *pResViewDesc); + +/** + * \brief Destroys a texture object + * + * Destroys the texture object specified by \p texObject. + * + * \param texObject - Texture object to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaDestroyTextureObject + */ +CUresult CUDAAPI cuTexObjectDestroy(CUtexObject texObject); + +/** + * \brief Returns a texture object's resource descriptor + * + * Returns the resource descriptor for the texture object specified by \p + * texObject. + * + * \param pResDesc - Resource descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectResourceDesc, + */ +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUtexObject texObject); + +/** + * \brief Returns a texture object's texture descriptor + * + * Returns the texture descriptor for the texture object specified by \p + * texObject. + * + * \param pTexDesc - Texture descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectTextureDesc + */ +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, + CUtexObject texObject); + +/** + * \brief Returns a texture object's resource view descriptor + * + * Returns the resource view descriptor for the texture object specified by \p + * texObject. If no resource view was set for \p texObject, the + * ::CUDA_ERROR_INVALID_VALUE is returned. 
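+ *
+ * \par
+ * As an illustrative sketch only (assuming \p dArray is an existing CUDA array
+ * of format ::CU_AD_FORMAT_FLOAT with one channel; error checks omitted), a
+ * texture object created without a resource view causes this query to return
+ * ::CUDA_ERROR_INVALID_VALUE:
+ * \code
+    CUDA_RESOURCE_DESC resDesc;
+    CUDA_TEXTURE_DESC texDesc;
+    CUDA_RESOURCE_VIEW_DESC viewDesc;
+    CUtexObject texObj;
+
+    memset(&resDesc, 0, sizeof(resDesc));
+    resDesc.resType = CU_RESOURCE_TYPE_ARRAY;
+    resDesc.res.array.hArray = dArray;
+
+    memset(&texDesc, 0, sizeof(texDesc));
+    texDesc.addressMode[0] = CU_TR_ADDRESS_MODE_CLAMP;
+    texDesc.filterMode     = CU_TR_FILTER_MODE_LINEAR;
+    texDesc.flags          = CU_TRSF_NORMALIZED_COORDINATES;
+
+    cuTexObjectCreate(&texObj, &resDesc, &texDesc, NULL);   // no resource view supplied
+    CUresult viewStatus = cuTexObjectGetResourceViewDesc(&viewDesc, texObj);
+    // viewStatus == CUDA_ERROR_INVALID_VALUE here, since no view was set
+    cuTexObjectDestroy(texObj);
+ * \endcode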
+ * + * \param pResViewDesc - Resource view descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectResourceViewDesc + */ +CUresult CUDAAPI cuTexObjectGetResourceViewDesc( + CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject); + +/** @} */ /* END CUDA_TEXOBJECT */ + +/** + * \defgroup CUDA_SURFOBJECT Surface Object Management + * + * ___MANBRIEF___ surface object management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the surface object management functions of the + * low-level CUDA driver application programming interface. The surface + * object API is only supported on devices of compute capability 3.0 or higher. + * + * @{ + */ + +/** + * \brief Creates a surface object + * + * Creates a surface object and returns it in \p pSurfObject. \p pResDesc + * describes the data to perform surface load/stores on. + * ::CUDA_RESOURCE_DESC::resType must be + * ::CU_RESOURCE_TYPE_ARRAY and ::CUDA_RESOURCE_DESC::res::array::hArray + * must be set to a valid CUDA array handle. ::CUDA_RESOURCE_DESC::flags must be + * set to zero. + * + * Surface objects are only supported on devices of compute capability 3.0 or + * higher. Additionally, a surface object is an opaque value, and, as such, + * should only be accessed through CUDA API calls. + * + * \param pSurfObject - Surface object to create + * \param pResDesc - Resource descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectDestroy, + * ::cudaCreateSurfaceObject + */ +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, + const CUDA_RESOURCE_DESC *pResDesc); + +/** + * \brief Destroys a surface object + * + * Destroys the surface object specified by \p surfObject. + * + * \param surfObject - Surface object to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectCreate, + * ::cudaDestroySurfaceObject + */ +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject); + +/** + * \brief Returns a surface object's resource descriptor + * + * Returns the resource descriptor for the surface object specified by \p + * surfObject. + * + * \param pResDesc - Resource descriptor + * \param surfObject - Surface object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectCreate, + * ::cudaGetSurfaceObjectResourceDesc + */ +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUsurfObject surfObject); + +/** @} */ /* END CUDA_SURFOBJECT */ + +/** + * \defgroup CUDA_TENSOR_MEMORY Tensor Map Object Management + * + * ___MANBRIEF___ tensor map object management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the tensor map object management functions of the + * low-level CUDA driver application programming interface. The tensor + * core API is only supported on devices of compute capability 9.0 or higher. 
+ * + * @{ + */ + +/** + * \brief Create a tensor map descriptor object representing tiled memory region + * + * Creates a descriptor for Tensor Memory Access (TMA) object specified + * by the parameters describing a tiled region and returns it in \p tensorMap. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or + higher. + * Additionally, a tensor map object is an opaque value, and, as such, should + only be + * accessed through CUDA API calls. + * + * The parameters passed are bound to the following requirements: + * + * - \p tensorMap address must be aligned to 64 bytes. + * + * - \p tensorDataType has to be an enum from ::CUtensorMapDataType which is + defined as: + * \code + typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, // 1 byte + CU_TENSOR_MAP_DATA_TYPE_UINT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_INT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_INT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ // 4 bytes + } CUtensorMapDataType; + * \endcode + * + * - \p tensorRank must be non-zero and less than or equal to the maximum + supported dimensionality of 5. If \p interleave is not + * ::CU_TENSOR_MAP_INTERLEAVE_NONE, then \p tensorRank must additionally be + greater than or equal to 3. + * + * - \p globalAddress, which specifies the starting address of the memory region + described, must be 32 byte aligned when \p interleave is + * ::CU_TENSOR_MAP_INTERLEAVE_32B and 16 byte aligned otherwise. + * + * - \p globalDim array, which specifies tensor size of each of the \p + tensorRank dimensions, must be non-zero and less than or + * equal to 2^32. + * + * - \p globalStrides array, which specifies tensor stride of each of the lower + \p tensorRank - 1 dimensions in bytes, must be a + * multiple of 16 and less than 2^40. Additionally, the stride must be a + multiple of 32 when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B. + * Each following dimension specified includes previous dimension stride: + * \code + globalStrides[0] = globalDim[0] * elementSizeInBytes(tensorDataType) + + padding[0]; for (i = 1; i < tensorRank - 1; i++) globalStrides[i] = + globalStrides[i – 1] * (globalDim[i] + padding[i]); assert(globalStrides[i] >= + globalDim[i]); + * \endcode + * + * - \p boxDim array, which specifies number of elements to be traversed along + each of the \p tensorRank dimensions, must be non-zero + * and less than or equal to 256. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE, { \p boxDim[0] * + elementSizeInBytes( \p tensorDataType ) } must be a multiple + * of 16 bytes. + * + * - \p elementStrides array, which specifies the iteration step along each of + the \p tensorRank dimensions, must be non-zero and less + * than or equal to 8. Note that when \p interleave is + ::CU_TENSOR_MAP_INTERLEAVE_NONE, the first element of this array is ignored + since + * TMA doesn’t support the stride for dimension zero. + * When all elements of \p elementStrides array is one, \p boxDim specifies the + number of elements to load. 
However, if the \p elementStrides[i] + * is not equal to one, then TMA loads ceil( \p boxDim[i] / \p + elementStrides[i]) number of elements along i-th dimension. To load N elements + along + * i-th dimension, \p boxDim[i] must be set to N * \p elementStrides[i]. + * + * - \p interleave specifies the interleaved layout of type + ::CUtensorMapInterleave, which is defined as: + * \code + typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B + } CUtensorMapInterleave; + * \endcode + * TMA supports interleaved layouts like NC/8HWC8 where C8 utilizes 16 bytes in + memory assuming 2 byte per channel or NC/16HWC16 where C16 + * uses 32 bytes. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE and \p swizzle is not + ::CU_TENSOR_MAP_SWIZZLE_NONE, the bounding box inner dimension + * (computed as \p boxDim[0] multiplied by element size derived from \p + tensorDataType) must be less than or equal to the swizzle size. + * - CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will + be <= 32. + * - CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will + be <= 64. + * - CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will + be <= 128. + * + * - \p swizzle, which specifies the shared memory bank swizzling pattern, has + to be of type ::CUtensorMapSwizzle which is defined as: + * \code + typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B + } CUtensorMapSwizzle; + * \endcode + * Data are organized in a specific order in global memory; however, this may + not match the order in which the application accesses data + * in shared memory. This difference in data organization may cause bank + conflicts when shared memory is accessed. In order to avoid this + * problem, data can be loaded to shared memory with shuffling across shared + memory banks. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B, \p swizzle must be + ::CU_TENSOR_MAP_SWIZZLE_32B. + * Other interleave modes can have any swizzling pattern. + * + * - \p l2Promotion specifies L2 fetch size which indicates the byte granurality + at which L2 requests is filled from DRAM. It must be of + * type ::CUtensorMapL2promotion, which is defined as: + * \code + typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + } CUtensorMapL2promotion; + * \endcode + * + * - \p oobFill, which indicates whether zero or a special NaN constant should + be used to fill out-of-bound elements, must be of type + * ::CUtensorMapFloatOOBfill which is defined as: + * \code + typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA + } CUtensorMapFloatOOBfill; + * \endcode + * Note that ::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA can only be + used when \p tensorDataType represents a floating-point data type. 
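+ *
+ * \par
+ * As an illustrative sketch only (assuming \p dPtr is a device allocation from
+ * ::cuMemAlloc() holding a 1024 x 1024 row-major float matrix; the values are
+ * placeholders chosen to satisfy the constraints listed above), a 64 x 64 tile
+ * descriptor could be encoded like this:
+ * \code
+    CUtensorMap tensorMap;                            // CUtensorMap is 64-byte aligned
+    cuuint64_t  globalDim[2]      = {1024, 1024};     // innermost dimension first
+    cuuint64_t  globalStrides[1]  = {1024 * sizeof(float)};   // row pitch in bytes
+    cuuint32_t  boxDim[2]         = {64, 64};         // 64 * sizeof(float) is a multiple of 16
+    cuuint32_t  elementStrides[2] = {1, 1};
+
+    cuTensorMapEncodeTiled(&tensorMap, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
+                           2, (void *)dPtr, globalDim, globalStrides, boxDim,
+                           elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
+                           CU_TENSOR_MAP_SWIZZLE_NONE,
+                           CU_TENSOR_MAP_L2_PROMOTION_NONE,
+                           CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+ * \endcode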
+ * + * \param tensorMap - Tensor map object to create + * \param tensorDataType - Tensor data type + * \param tensorRank - Dimensionality of tensor + * \param globalAddress - Starting address of memory region described by + tensor + * \param globalDim - Array containing tensor size (number of elements) + along each of the \p tensorRank dimensions + * \param globalStrides - Array containing stride size (in bytes) along each + of the \p tensorRank - 1 dimensions + * \param boxDim - Array containing traversal box size (number of + elements) along each of the \p tensorRank dimensions. Specifies how many + elements to be traversed along each tensor dimension. + * \param elementStrides - Array containing traversal stride in each of the + \p tensorRank dimensions + * \param interleave - Type of interleaved layout the tensor addresses + * \param swizzle - Bank swizzling pattern inside shared memory + * \param l2Promotion - L2 promotion size + * \param oobFill - Indicate whether zero or special NaN constant must + be used to fill out-of-bound elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeIm2col, + * ::cuTensorMapReplaceAddress + */ +CUresult CUDAAPI cuTensorMapEncodeTiled( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); + +/** + * \brief Create a tensor map descriptor object representing im2col memory + region + * + * Creates a descriptor for Tensor Memory Access (TMA) object specified + * by the parameters describing a im2col memory layout and returns it in \p + tensorMap. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or + higher. + * Additionally, a tensor map object is an opaque value, and, as such, should + only be + * accessed through CUDA API calls. + * + * The parameters passed are bound to the following requirements: + * + * - \p tensorMap address must be aligned to 64 bytes. + * + * - \p tensorDataType has to be an enum from ::CUtensorMapDataType which is + defined as: + * \code + typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, // 1 byte + CU_TENSOR_MAP_DATA_TYPE_UINT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_INT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_INT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ // 4 bytes + } CUtensorMapDataType; + * \endcode + * + * - \p tensorRank, which specifies the number of tensor dimensions, must be 3, + 4, or 5. + * + * - \p globalAddress, which specifies the starting address of the memory region + described, must be 32 byte aligned when \p interleave is + * ::CU_TENSOR_MAP_INTERLEAVE_32B and 16 byte aligned otherwise. 
+ * + * - \p globalDim array, which specifies tensor size of each of the \p + tensorRank dimensions, must be non-zero and less than or + * equal to 2^32. + * + * - \p globalStrides array, which specifies tensor stride of each of the lower + \p tensorRank - 1 dimensions in bytes, must be a + * multiple of 16 and less than 2^40. Additionally, the stride must be a + multiple of 32 when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B. + * Each following dimension specified includes previous dimension stride: + * \code + globalStrides[0] = globalDim[0] * elementSizeInBytes(tensorDataType) + + padding[0]; for (i = 1; i < tensorRank - 1; i++) globalStrides[i] = + globalStrides[i – 1] * (globalDim[i] + padding[i]); assert(globalStrides[i] >= + globalDim[i]); + * \endcode + * + * - \p pixelBoxLowerCorner array specifies the coordinate offsets {D, H, W} of + the bounding box from top/left/front corner. The number of + * offsets and their precision depend on the tensor dimensionality: + * - When \p tensorRank is 3, one signed offset within range [-32768, 32767] + is supported. + * - When \p tensorRank is 4, two signed offsets each within range [-128, + 127] are supported. + * - When \p tensorRank is 5, three offsets each within range [-16, 15] are + supported. + * + * - \p pixelBoxUpperCorner array specifies the coordinate offsets {D, H, W} of + the bounding box from bottom/right/back corner. The number of + * offsets and their precision depend on the tensor dimensionality: + * - When \p tensorRank is 3, one signed offset within range [-32768, 32767] + is supported. + * - When \p tensorRank is 4, two signed offsets each within range [-128, + 127] are supported. + * - When \p tensorRank is 5, three offsets each within range [-16, 15] are + supported. + * The bounding box specified by \p pixelBoxLowerCorner and \p + pixelBoxUpperCorner must have non-zero area. + * + * - \p channelsPerPixel, which specifies the number of elements which must be + accessed along C dimension, must be less than or equal to 256. + * + * - \p pixelsPerColumn, which specifies the number of elements that must be + accessed along the {N, D, H, W} dimensions, must be less than or + * equal to 1024. + * + * - \p elementStrides array, which specifies the iteration step along each of + the \p tensorRank dimensions, must be non-zero and less + * than or equal to 8. Note that when \p interleave is + ::CU_TENSOR_MAP_INTERLEAVE_NONE, the first element of this array is ignored + since + * TMA doesn’t support the stride for dimension zero. + * When all elements of the \p elementStrides array are one, \p boxDim specifies + the number of elements to load. However, if \p elementStrides[i] + * is not equal to one for some \p i, then TMA loads ceil( \p boxDim[i] / \p + elementStrides[i]) number of elements along i-th dimension. + * To load N elements along i-th dimension, \p boxDim[i] must be set to N * \p + elementStrides[i]. + * + * - \p interleave specifies the interleaved layout of type + ::CUtensorMapInterleave, which is defined as: + * \code + typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B + } CUtensorMapInterleave; + * \endcode + * TMA supports interleaved layouts like NC/8HWC8 where C8 utilizes 16 bytes in + memory assuming 2 byte per channel or NC/16HWC16 where C16 + * uses 32 bytes. 
+ * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE and \p swizzle is not + ::CU_TENSOR_MAP_SWIZZLE_NONE, the bounding box inner dimension + * (computed as \p boxDim[0] multiplied by element size derived from \p + tensorDataType) must be less than or equal to the swizzle size. + * - CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will + be <= 32. + * - CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will + be <= 64. + * - CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will + be <= 128. + * + * - \p swizzle, which specifies the shared memory bank swizzling pattern, has + to be of type ::CUtensorMapSwizzle which is defined as: + * \code + typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B + } CUtensorMapSwizzle; + * \endcode + * Data are organized in a specific order in global memory; however, this may + not match the order in which the application accesses data + * in shared memory. This difference in data organization may cause bank + conflicts when shared memory is accessed. In order to avoid this + * problem, data can be loaded to shared memory with shuffling across shared + memory banks. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B, \p swizzle must be + ::CU_TENSOR_MAP_SWIZZLE_32B. + * Other interleave modes can have any swizzling pattern. + * + * - \p l2Promotion specifies L2 fetch size which indicates the byte granularity + at which L2 requests are filled from DRAM. It must be of + * type ::CUtensorMapL2promotion, which is defined as: + * \code + typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + } CUtensorMapL2promotion; + * \endcode + * + * - \p oobFill, which indicates whether zero or a special NaN constant should + be used to fill out-of-bound elements, must be of type + * ::CUtensorMapFloatOOBfill which is defined as: + * \code + typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA + } CUtensorMapFloatOOBfill; + * \endcode + * Note that ::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA can only be + used when \p tensorDataType represents a floating-point data type. 
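+ *
+ * \par
+ * As an illustrative sketch only (assuming \p dPtr points to an NHWC float16
+ * tensor with N=1, H=56, W=56, C=64; all values are placeholders chosen to
+ * satisfy the constraints listed above, with corner offsets of zero, i.e. no
+ * halo around the bounding box), an im2col descriptor could be encoded like
+ * this:
+ * \code
+    CUtensorMap tensorMap;
+    cuuint64_t  globalDim[4]      = {64, 56, 56, 1};                    // {C, W, H, N}
+    cuuint64_t  globalStrides[3]  = {64 * 2, 64 * 56 * 2, 64 * 56 * 56 * 2};  // bytes
+    int         lowerCorner[2]    = {0, 0};
+    int         upperCorner[2]    = {0, 0};
+    cuuint32_t  channelsPerPixel  = 64;
+    cuuint32_t  pixelsPerColumn   = 64;
+    cuuint32_t  elementStrides[4] = {1, 1, 1, 1};
+
+    cuTensorMapEncodeIm2col(&tensorMap, CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
+                            4, (void *)dPtr, globalDim, globalStrides,
+                            lowerCorner, upperCorner, channelsPerPixel,
+                            pixelsPerColumn, elementStrides,
+                            CU_TENSOR_MAP_INTERLEAVE_NONE,
+                            CU_TENSOR_MAP_SWIZZLE_NONE,
+                            CU_TENSOR_MAP_L2_PROMOTION_NONE,
+                            CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+ * \endcode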
+ * + * \param tensorMap - Tensor map object to create + * \param tensorDataType - Tensor data type + * \param tensorRank - Dimensionality of tensor; must be at least 3 + * \param globalAddress - Starting address of memory region described by + tensor + * \param globalDim - Array containing tensor size (number of + elements) along each of the \p tensorRank dimensions + * \param globalStrides - Array containing stride size (in bytes) along + each of the \p tensorRank - 1 dimensions + * \param pixelBoxLowerCorner - Array containing DHW dimensions of lower box + corner + * \param pixelBoxUpperCorner - Array containing DHW dimensions of upper box + corner + * \param channelsPerPixel - Number of channels per pixel + * \param pixelsPerColumn - Number of pixels per column + * \param elementStrides - Array containing traversal stride in each of + the \p tensorRank dimensions + * \param interleave - Type of interleaved layout the tensor + addresses + * \param swizzle - Bank swizzling pattern inside shared memory + * \param l2Promotion - L2 promotion size + * \param oobFill - Indicate whether zero or special NaN constant + will be used to fill out-of-bound elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeTiled, + * ::cuTensorMapReplaceAddress + */ +CUresult CUDAAPI cuTensorMapEncodeIm2col( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const int *pixelBoxLowerCorner, + const int *pixelBoxUpperCorner, cuuint32_t channelsPerPixel, + cuuint32_t pixelsPerColumn, const cuuint32_t *elementStrides, + CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + +/** + * \brief Modify an existing tensor map descriptor with an updated global + * address + * + * Modifies the descriptor for Tensor Memory Access (TMA) object passed in \p + * tensorMap with an updated \p globalAddress. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or + * higher. Additionally, a tensor map object is an opaque value, and, as such, + * should only be accessed through CUDA API calls. + * + * \param tensorMap - Tensor map object to modify + * \param globalAddress - Starting address of memory region described by + * tensor, must follow previous alignment requirements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeTiled, + * ::cuTensorMapEncodeIm2col + */ +CUresult CUDAAPI cuTensorMapReplaceAddress(CUtensorMap *tensorMap, + void *globalAddress); + +/** @} */ +/* END CUDA_TENSOR_MEMORY */ + +/** + * \defgroup CUDA_PEER_ACCESS Peer Context Memory Access + * + * ___MANBRIEF___ direct peer context memory access functions of the low-level + * CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the direct peer context memory access functions + * of the low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Queries if a device may directly access a peer device's memory. + * + * Returns in \p *canAccessPeer a value of 1 if contexts on \p dev are capable + * of directly accessing memory from contexts on \p peerDev and 0 otherwise. 
If + * direct access of \p peerDev from \p dev is possible, then access may be + * enabled on two specific contexts by calling ::cuCtxEnablePeerAccess(). + * + * \param canAccessPeer - Returned access capability + * \param dev - Device from which allocations on \p peerDev are to + * be directly accessed. + * \param peerDev - Device on which the allocations to be directly + * accessed by \p dev reside. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuCtxEnablePeerAccess, + * ::cuCtxDisablePeerAccess, + * ::cudaDeviceCanAccessPeer + */ +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, + CUdevice peerDev); + +/** + * \brief Enables direct access to memory allocations in a peer context. + * + * If both the current context and \p peerContext are on devices which support + * unified addressing (as may be queried using + * ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING) and same major compute capability, + * then on success all allocations from \p peerContext will immediately be + * accessible by the current context. See \ref CUDA_UNIFIED for additional + * details. + * + * Note that access granted by this call is unidirectional and that in order to + * access memory from the current context in \p peerContext, a separate + * symmetric call to ::cuCtxEnablePeerAccess() is required. + * + * Note that there are both device-wide and system-wide limitations per system + * configuration, as noted in the CUDA Programming Guide under the section + * "Peer-to-Peer Memory Access". + * + * Returns ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED if ::cuDeviceCanAccessPeer() + * indicates that the ::CUdevice of the current context cannot directly access + * memory from the ::CUdevice of \p peerContext. + * + * Returns ::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED if direct access of + * \p peerContext from the current context has already been enabled. + * + * Returns ::CUDA_ERROR_TOO_MANY_PEERS if direct peer access is not possible + * because hardware resources required for peer access have been exhausted. + * + * Returns ::CUDA_ERROR_INVALID_CONTEXT if there is no current context, \p + * peerContext is not a valid context, or if the current context is \p + * peerContext. + * + * Returns ::CUDA_ERROR_INVALID_VALUE if \p Flags is not 0. + * + * \param peerContext - Peer context to enable direct access to from the current + * context \param Flags - Reserved for future use and must be set to 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED, + * ::CUDA_ERROR_TOO_MANY_PEERS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuDeviceCanAccessPeer, + * ::cuCtxDisablePeerAccess, + * ::cudaDeviceEnablePeerAccess + */ +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int Flags); + +/** + * \brief Disables direct access to memory allocations in a peer context and + * unregisters any registered allocations. + * + Returns ::CUDA_ERROR_PEER_ACCESS_NOT_ENABLED if direct peer access has + * not yet been enabled from \p peerContext to the current context. + * + * Returns ::CUDA_ERROR_INVALID_CONTEXT if there is no current context, or if + * \p peerContext is not a valid context. 
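+ *
+ * \par
+ * As an illustrative sketch only (assuming devices \p devA and \p devB with
+ * existing contexts \p ctxA and \p ctxB; error handling trimmed), a one-way
+ * peer mapping is typically enabled, used, and then torn down like this:
+ * \code
+    int canAccess = 0;
+    cuDeviceCanAccessPeer(&canAccess, devA, devB);
+    if (canAccess) {
+        cuCtxSetCurrent(ctxA);
+        cuCtxEnablePeerAccess(ctxB, 0);   // ctxA can now access ctxB's allocations
+        // ... launch work in ctxA that reads or writes memory owned by ctxB ...
+        cuCtxDisablePeerAccess(ctxB);     // unregister and revoke the mapping
+    }
+ * \endcode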
+ * + * \param peerContext - Peer context to disable direct access to + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_PEER_ACCESS_NOT_ENABLED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * \notefnerr + * + * \sa + * ::cuDeviceCanAccessPeer, + * ::cuCtxEnablePeerAccess, + * ::cudaDeviceDisablePeerAccess + */ +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext); + +/** + * \brief Queries attributes of the link between two devices. + * + * Returns in \p *value the value of the requested attribute \p attrib of the + * link between \p srcDevice and \p dstDevice. The supported attributes are: + * - ::CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK: A relative value indicating the + * performance of the link between two devices. + * - ::CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED P2P: 1 if P2P Access is enable. + * - ::CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED: 1 if Atomic operations + * over the link are supported. + * - ::CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED: 1 if cudaArray can + * be accessed over the link. + * + * Returns ::CUDA_ERROR_INVALID_DEVICE if \p srcDevice or \p dstDevice are not + * valid or if they represent the same device. + * + * Returns ::CUDA_ERROR_INVALID_VALUE if \p attrib is not valid or if \p value + * is a null pointer. + * + * \param value - Returned value of the requested attribute + * \param attrib - The requested attribute of the link between \p + * srcDevice and \p dstDevice. \param srcDevice - The source device of the + * target link. \param dstDevice - The destination device of the target + * link. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuCtxEnablePeerAccess, + * ::cuCtxDisablePeerAccess, + * ::cuDeviceCanAccessPeer, + * ::cudaDeviceGetP2PAttribute + */ +CUresult CUDAAPI cuDeviceGetP2PAttribute(int *value, + CUdevice_P2PAttribute attrib, + CUdevice srcDevice, + CUdevice dstDevice); + +/** @} */ /* END CUDA_PEER_ACCESS */ + +/** + * \defgroup CUDA_GRAPHICS Graphics Interoperability + * + * ___MANBRIEF___ graphics interoperability functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the graphics interoperability functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Unregisters a graphics resource for access by CUDA + * + * Unregisters the graphics resource \p resource so it is not accessible by + * CUDA unless registered again. + * + * If \p resource is invalid then ::CUDA_ERROR_INVALID_HANDLE is + * returned. + * + * \param resource - Resource to unregister + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuGraphicsD3D9RegisterResource, + * ::cuGraphicsD3D10RegisterResource, + * ::cuGraphicsD3D11RegisterResource, + * ::cuGraphicsGLRegisterBuffer, + * ::cuGraphicsGLRegisterImage, + * ::cudaGraphicsUnregisterResource + */ +CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource); + +/** + * \brief Get an array through which to access a subresource of a mapped + * graphics resource. 
+ * + * Returns in \p *pArray an array through which the subresource of the mapped + * graphics resource \p resource which corresponds to array index \p arrayIndex + * and mipmap level \p mipLevel may be accessed. The value set in \p *pArray + * may change every time that \p resource is mapped. + * + * If \p resource is not a texture then it cannot be accessed via an array and + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY is returned. + * If \p arrayIndex is not a valid array index for \p resource then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * If \p mipLevel is not a valid mipmap level for \p resource then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned. + * + * \param pArray - Returned array through which a subresource of \p + * resource may be accessed \param resource - Mapped resource to access + * \param arrayIndex - Array index for array textures or cubemap face + * index as defined by ::CUarray_cubemap_face for + * cubemap textures for the subresource to access + * \param mipLevel - Mipmap level for the subresource to access + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY + * \notefnerr + * + * \sa + * ::cuGraphicsResourceGetMappedPointer, + * ::cudaGraphicsSubResourceGetMappedArray + */ +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray( + CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, + unsigned int mipLevel); + +/** + * \brief Get a mipmapped array through which to access a mapped graphics + * resource. + * + * Returns in \p *pMipmappedArray a mipmapped array through which the mapped + * graphics resource \p resource. The value set in \p *pMipmappedArray may + * change every time that \p resource is mapped. + * + * If \p resource is not a texture then it cannot be accessed via a mipmapped + * array and + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY is returned. + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned. + * + * \param pMipmappedArray - Returned mipmapped array through which \p resource + * may be accessed \param resource - Mapped resource to access + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY + * \notefnerr + * + * \sa + * ::cuGraphicsResourceGetMappedPointer, + * ::cudaGraphicsResourceGetMappedMipmappedArray + */ +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray( + CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource); + +/** + * \brief Get a device pointer through which to access a mapped graphics + * resource. + * + * Returns in \p *pDevPtr a pointer through which the mapped graphics resource + * \p resource may be accessed. + * Returns in \p pSize the size of the memory in bytes which may be accessed + * from that pointer. The value set in \p pPointer may change every time that \p + * resource is mapped. + * + * If \p resource is not a buffer then it cannot be accessed via a pointer and + * ::CUDA_ERROR_NOT_MAPPED_AS_POINTER is returned. + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned. 
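+ *
+ * \par
+ * As an illustrative sketch only (assuming \p resource is a registered buffer,
+ * e.g. from ::cuGraphicsGLRegisterBuffer(), that has already been mapped with
+ * ::cuGraphicsMapResources()), the device pointer and size are fetched as:
+ * \code
+    CUdeviceptr devPtr = 0;
+    size_t numBytes = 0;
+    if (cuGraphicsResourceGetMappedPointer(&devPtr, &numBytes, resource) ==
+        CUDA_SUCCESS) {
+        // devPtr is valid only until the resource is unmapped
+    }
+ * \endcode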
+ * * + * \param pDevPtr - Returned pointer through which \p resource may be + * accessed \param pSize - Returned size of the buffer accessible starting + * at \p *pPointer \param resource - Mapped resource to access + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_NOT_MAPPED_AS_POINTER + * \notefnerr + * + * \sa + * ::cuGraphicsMapResources, + * ::cuGraphicsSubResourceGetMappedArray, + * ::cudaGraphicsResourceGetMappedPointer + */ +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource); + +/** + * \brief Set usage flags for mapping a graphics resource + * + * Set \p flags for mapping the graphics resource \p resource. + * + * Changes to \p flags will take effect the next time \p resource is mapped. + * The \p flags argument may be any of the following: + + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE: Specifies no hints about how this + * resource will be used. It is therefore assumed that this resource will be + * read from and written to by CUDA kernels. This is the default value. + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_READONLY: Specifies that CUDA kernels + which + * access this resource will not write to this resource. + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITEDISCARD: Specifies that CUDA kernels + * which access this resource will not read from this resource and will + * write over the entire contents of the resource, so none of the data + * previously stored in the resource will be preserved. + * + * If \p resource is presently mapped for access by CUDA then + * ::CUDA_ERROR_ALREADY_MAPPED is returned. + * If \p flags is not one of the above values then ::CUDA_ERROR_INVALID_VALUE is + returned. + * + * \param resource - Registered resource to set flags for + * \param flags - Parameters for resource mapping + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ALREADY_MAPPED + * \notefnerr + * + * \sa + * ::cuGraphicsMapResources, + * ::cudaGraphicsResourceSetMapFlags + */ +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags); + +/** + * \brief Map graphics resources for access by CUDA + * + * Maps the \p count graphics resources in \p resources for access by CUDA. + * + * The resources in \p resources may be accessed by CUDA until they + * are unmapped. The graphics API from which \p resources were registered + * should not access any resources while they are mapped by CUDA. If an + * application does so, the results are undefined. + * + * This function provides the synchronization guarantee that any graphics calls + * issued before ::cuGraphicsMapResources() will complete before any subsequent + * CUDA work issued in \p stream begins. + * + * If \p resources includes any duplicate entries then + * ::CUDA_ERROR_INVALID_HANDLE is returned. If any of \p resources are presently + * mapped for access by CUDA then ::CUDA_ERROR_ALREADY_MAPPED is returned. 
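+ *
+ * \par
+ * As an illustrative sketch only (assuming \p vboResource and \p texResource
+ * are two distinct registered resources and \p hStream is an existing stream,
+ * which may be 0), a typical map/use/unmap sequence looks like this:
+ * \code
+    CUgraphicsResource resources[2] = {vboResource, texResource};
+    cuGraphicsMapResources(2, resources, hStream);
+    // ... access the resources from CUDA work issued in hStream ...
+    cuGraphicsUnmapResources(2, resources, hStream);
+ * \endcode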
+ * + * \param count - Number of resources to map + * \param resources - Resources to map for CUDA usage + * \param hStream - Stream with which to synchronize + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ALREADY_MAPPED, + * ::CUDA_ERROR_UNKNOWN + * \note_null_stream + * \notefnerr + * + * \sa + * ::cuGraphicsResourceGetMappedPointer, + * ::cuGraphicsSubResourceGetMappedArray, + * ::cuGraphicsUnmapResources, + * ::cudaGraphicsMapResources + */ +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream); + +/** + * \brief Unmap graphics resources. + * + * Unmaps the \p count graphics resources in \p resources. + * + * Once unmapped, the resources in \p resources may not be accessed by CUDA + * until they are mapped again. + * + * This function provides the synchronization guarantee that any CUDA work + * issued in \p stream before ::cuGraphicsUnmapResources() will complete before + * any subsequently issued graphics work begins. + * + * + * If \p resources includes any duplicate entries then + * ::CUDA_ERROR_INVALID_HANDLE is returned. If any of \p resources are not + * presently mapped for access by CUDA then ::CUDA_ERROR_NOT_MAPPED is returned. + * + * \param count - Number of resources to unmap + * \param resources - Resources to unmap + * \param hStream - Stream with which to synchronize + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_UNKNOWN + * \note_null_stream + * \notefnerr + * + * \sa + * ::cuGraphicsMapResources, + * ::cudaGraphicsUnmapResources + */ +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream); + +/** @} */ /* END CUDA_GRAPHICS */ + +/** + * \defgroup CUDA_DRIVER_ENTRY_POINT Driver Entry Point Access + * + * ___MANBRIEF___ driver entry point access functions of the low-level CUDA + * driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the driver entry point access functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns the requested driver API function pointer + * + * Returns in \p **pfn the address of the CUDA driver function for the requested + * CUDA version and flags. + * + * The CUDA version is specified as (1000 * major + 10 * minor), so CUDA 11.2 + * should be specified as 11020. For a requested driver symbol, if the specified + * CUDA version is greater than or equal to the CUDA version in which the driver + * symbol was introduced, this API will return the function pointer to the + * corresponding versioned function. + * + * The pointer returned by the API should be cast to a function pointer matching + * the requested driver function's definition in the API header file. The + * function pointer typedef can be picked up from the corresponding typedefs + * header file. For example, cudaTypedefs.h consists of function pointer + * typedefs for driver APIs defined in cuda.h. + * + * The API will return ::CUDA_SUCCESS and set the returned \p pfn to NULL if the + * requested driver function is not supported on the platform, no ABI + * compatible driver function exists for the specified \p cudaVersion or if the + * driver symbol is invalid. 
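+ *
+ * An illustrative lookup sketch (the function-pointer typedef is assumed to
+ * come from cudaTypedefs.h and the version value 12000 is only an example):
+ * \code
+ *   PFN_cuMemAlloc_v3020 pfnMemAlloc = NULL;
+ *   CUdriverProcAddressQueryResult status;
+ *   cuGetProcAddress("cuMemAlloc", (void **)&pfnMemAlloc, 12000,
+ *                    CU_GET_PROC_ADDRESS_DEFAULT, &status);
+ *   if (status == CU_GET_PROC_ADDRESS_SUCCESS && pfnMemAlloc != NULL) {
+ *     // call the driver through the returned entry point
+ *   }
+ * \endcode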
+ * + * It will also set the optional \p symbolStatus to one of the values in + * ::CUdriverProcAddressQueryResult with the following meanings: + * - ::CU_GET_PROC_ADDRESS_SUCCESS - The requested symbol was successfully found + * based on input arguments and \p pfn is valid + * - ::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND - The requested symbol was not found + * - ::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT - The requested symbol was + * found but is not supported by cudaVersion specified + * + * The requested flags can be: + * - ::CU_GET_PROC_ADDRESS_DEFAULT: This is the default mode. This is equivalent + * to + * ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM if the code is compiled + * with + * --default-stream per-thread compilation flag or the macro + * CUDA_API_PER_THREAD_DEFAULT_STREAM is defined; + * ::CU_GET_PROC_ADDRESS_LEGACY_STREAM otherwise. + * - ::CU_GET_PROC_ADDRESS_LEGACY_STREAM: This will enable the search for all + * driver symbols that match the requested driver symbol name except the + * corresponding per-thread versions. + * - ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM: This will enable the + * search for all driver symbols that match the requested driver symbol name + * including the per-thread versions. If a per-thread version is not found, the + * API will return the legacy version of the driver function. + * + * \param symbol - The base name of the driver API function to look for. As an + * example, for the driver API ::cuMemAlloc_v2, \p symbol would be cuMemAlloc + * and \p cudaVersion would be the ABI compatible CUDA version for the _v2 + * variant. \param pfn - Location to return the function pointer to the + * requested driver function \param cudaVersion - The CUDA version to look for + * the requested driver symbol \param flags - Flags to specify search options. + * \param symbolStatus - Optional location to store the status of the search for + * \p symbol based on \p cudaVersion. See + * ::CUdriverProcAddressQueryResult for possible values. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_version_mixing + * + * \sa + * ::cudaGetDriverEntryPoint + */ +CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, + int cudaVersion, cuuint64_t flags, + CUdriverProcAddressQueryResult *symbolStatus); + +/** @} */ /* END CUDA_DRIVER_ENTRY_POINT */ + +/** + * \defgroup CUDA_COREDUMP Coredump Attributes Control API + * + * ___MANBRIEF___ coredump attribute control functions for the low-level CUDA + * API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the coredump attribute control functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * Flags for choosing a coredump attribute to get/set + */ +typedef enum CUcoredumpSettings_enum { + CU_COREDUMP_ENABLE_ON_EXCEPTION = 1, + CU_COREDUMP_TRIGGER_HOST, + CU_COREDUMP_LIGHTWEIGHT, + CU_COREDUMP_ENABLE_USER_TRIGGER, + CU_COREDUMP_FILE, + CU_COREDUMP_PIPE, + CU_COREDUMP_MAX +} CUcoredumpSettings; + +/** + * \brief Allows caller to fetch a coredump attribute value for the current + * context + * + * Returns in \p *value the requested value specified by \p attrib. It is up to + * the caller to ensure that the data type and size of \p *value matches the + * request. + * + * If the caller calls this function with \p *value equal to NULL, the size of + * the memory region (in bytes) expected for \p attrib will be placed in \p + * size. 
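+ *
+ * For example, fetching the current coredump file name could look like the
+ * following sketch (the 1024-byte buffer is an assumption sized to the
+ * documented 1023-character limit plus terminator):
+ * \code
+ *   char path[1024];
+ *   size_t size = sizeof(path);
+ *   cuCoredumpGetAttribute(CU_COREDUMP_FILE, path, &size);
+ * \endcode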
+ * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU + * exceptions from this context will create a coredump at the location specified + * by ::CU_COREDUMP_FILE. The default value is ::false unless set to ::true + * globally or locally, or the CU_CTX_USER_COREDUMP_ENABLE flag was set during + * context creation. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true unless set to + * ::false globally or or locally. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting + * coredumps will not have a dump of GPU memory or non-reloc ELF images. The + * default value is + * ::false unless set to ::true globally or locally. + * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump + * can be created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. + * The default value is ::false unless set to ::true globally or locally. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the + * location where any coredumps generated by this context will be written. The + * default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA applications and ::PID is the process ID of the CUDA + * application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name + * of the pipe that will be monitored if user-triggered coredumps are enabled. + * The default value is + * ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA application and ::PID is the process ID of the CUDA + * application. + * + * \param attrib - The enum defining which value to fetch. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpSetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpGetAttribute(CUcoredumpSettings attrib, void *value, + size_t *size); + +/** + * \brief Allows caller to fetch a coredump attribute value for the entire + * application + * + * Returns in \p *value the requested value specified by \p attrib. It is up to + * the caller to ensure that the data type and size of \p *value matches the + * request. + * + * If the caller calls this function with \p *value equal to NULL, the size of + * the memory region (in bytes) expected for \p attrib will be placed in \p + * size. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU + * exceptions from this context will create a coredump at the location specified + * by ::CU_COREDUMP_FILE. The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting + * coredumps will not have a dump of GPU memory or non-reloc ELF images. The + * default value is + * ::false. 
+ * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump + * can be created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. + * The default value is ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the + * location where any coredumps generated by this context will be written. The + * default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA applications and ::PID is the process ID of the CUDA + * application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name + * of the pipe that will be monitored if user-triggered coredumps are enabled. + * The default value is + * ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA application and ::PID is the process ID of the CUDA + * application. + * + * \param attrib - The enum defining which value to fetch. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuCoredumpGetAttribute, + * ::cuCoredumpSetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpGetAttributeGlobal(CUcoredumpSettings attrib, + void *value, size_t *size); + +/** + * \brief Allows caller to set a coredump attribute value for the current + * context + * + * This function should be considered an alternate interface to the CUDA-GDB + * environment variables defined in this document: + * https://docs.nvidia.com/cuda/cuda-gdb/index.html#gpu-coredump + * + * An important design decision to note is that any coredump environment + * variable values set before CUDA initializes will take permanent precedence + * over any values set with this this function. This decision was made to ensure + * no change in behavior for any users that may be currently using these + * variables to get coredumps. + * + * \p *value shall contain the requested value specified by \p set. It is up to + * the caller to ensure that the data type and size of \p *value matches the + * request. + * + * If the caller calls this function with \p *value equal to NULL, the size of + * the memory region (in bytes) expected for \p set will be placed in \p size. + * + * /note This function will return ::CUDA_ERROR_NOT_SUPPORTED if the caller + * attempts to set + * ::CU_COREDUMP_ENABLE_ON_EXCEPTION on a GPU of with Compute Capability < 6.0. + * ::cuCoredumpSetAttributeGlobal works on those platforms as an alternative. + * + * /note ::CU_COREDUMP_ENABLE_USER_TRIGGER and ::CU_COREDUMP_PIPE cannot be set + * on a per-context basis. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU + * exceptions from this context will create a coredump at the location specified + * by ::CU_COREDUMP_FILE. The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting + * coredumps will not have a dump of GPU memory or non-reloc ELF images. The + * default value is + * ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the + * location where any coredumps generated by this context will be written. 
The + * default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA applications and ::PID is the process ID of the CUDA + * application. + * + * \param attrib - The enum defining which value to set. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpGetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpSetAttribute(CUcoredumpSettings attrib, void *value, + size_t *size); + +/** + * \brief Allows caller to set a coredump attribute value globally + * + * This function should be considered an alternate interface to the CUDA-GDB + * environment variables defined in this document: + * https://docs.nvidia.com/cuda/cuda-gdb/index.html#gpu-coredump + * + * An important design decision to note is that any coredump environment + * variable values set before CUDA initializes will take permanent precedence + * over any values set with this this function. This decision was made to ensure + * no change in behavior for any users that may be currently using these + * variables to get coredumps. + * + * \p *value shall contain the requested value specified by \p set. It is up to + * the caller to ensure that the data type and size of \p *value matches the + * request. + * + * If the caller calls this function with \p *value equal to NULL, the size of + * the memory region (in bytes) expected for \p set will be placed in \p size. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU + * exceptions from this context will create a coredump at the location specified + * by ::CU_COREDUMP_FILE. The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting + * coredumps will not have a dump of GPU memory or non-reloc ELF images. The + * default value is + * ::false. + * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump + * can be created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. + * The default value is ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the + * location where any coredumps generated by this context will be written. The + * default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the + * machine running the CUDA applications and ::PID is the process ID of the CUDA + * application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name + * of the pipe that will be monitored if user-triggered coredumps are enabled. + * This value may not be changed after ::CU_COREDUMP_ENABLE_USER_TRIGGER is set + * to ::true. The default value is ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME + * is the host name of the machine running the CUDA application and ::PID is the + * process ID of the CUDA application. + * + * \param attrib - The enum defining which value to set. + * \param value - void* containing the requested data. 
+ * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED + * + * \sa + * ::cuCoredumpGetAttribute, + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpSetAttribute + */ +CUresult CUDAAPI cuCoredumpSetAttributeGlobal(CUcoredumpSettings attrib, + void *value, size_t *size); + +/** @} */ /* END CUDA_COREDUMP */ + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, + const CUuuid *pExportTableId); + +/* +** ******************* GREEN CONTEXTS ********************** +*/ + +/** + * \defgroup CUDA_GREEN_CONTEXTS Green Contexts + * + * ___MANBRIEF___ Driver level API for creation and manipulation of green + * contexts + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the APIs for creation and manipulation of green + * contexts in the CUDA driver. Green contexts are a lightweight alternative to + * traditional contexts, with the ability to pass in a set of resources that + * they should be initialized with. This allows the developer to represent + * distinct spatial partitions of the GPU, provision resources for them, and + * target them via the same programming model that CUDA exposes (streams, kernel + * launches, etc.). + * + * There are 4 main steps to using these new set of APIs. + * - (1) Start with an initial set of resources, for example via + * ::cuDeviceGetDevResource. Only SM type is supported today. + * - (2) Partition this set of resources by providing them as input to a + * partition API, for example: ::cuDevSmResourceSplitByCount. + * - (3) Finalize the specification of resources by creating a descriptor via + * ::cuDevResourceGenerateDesc. + * - (4) Provision the resources and create a green context via + * ::cuGreenCtxCreate. + * + * For \p CU_DEV_RESOURCE_TYPE_SM, the partitions created have minimum SM count + * requirements, often rounding up and aligning the minCount provided to + * ::cuDevSmResourceSplitByCount. The following is a guideline for each + * architecture and may be subject to change: + * - On Compute Architecture 6.X: The minimum count is 1 SM. + * - On Compute Architecture 7.X: The minimum count is 2 SMs and must be a + * multiple of 2. + * - On Compute Architecture 8.X: The minimum count is 4 SMs and must be a + * multiple of 2. + * - On Compute Architecture 9.0+: The minimum count is 8 SMs and must be a + * multiple of 8. + * + * In the future, flags can be provided to tradeoff functional and performance + * characteristics versus finer grained SM partitions. + * + * Even if the green contexts have disjoint SM partitions, it is not guaranteed + * that the kernels launched in them will run concurrently or have forward + * progress guarantees. This is due to other resources (like HW connections, see + * ::CUDA_DEVICE_MAX_CONNECTIONS) that could cause a dependency. Additionally, + * in certain scenarios, it is possible for the workload to run on more SMs than + * was provisioned (but never less). The following are two scenarios which can + * exhibit this behavior: + * - On Volta+ MPS: When \p CUDA_MPS_ACTIVE_THREAD_PERCENTAGE is used, + * the set of SMs that are used for running kernels can be scaled up to the + * value of SMs used for the MPS client. + * - On Compute Architecture 9.x: When a module with dynamic parallelism (CDP) + * is loaded, all future kernels running under green contexts may use and share + * an additional set of 2 SMs. + * + * @{ + */ + +/*! 
+ * \typedef typedef struct CUgreenCtx_st* CUgreenCtx + * A green context handle. This handle can be used safely from only one CPU + * thread at a time. Created via ::cuGreenCtxCreate + */ +typedef struct CUgreenCtx_st *CUgreenCtx; + +/*! + * \typedef struct CUdevResourceDesc_st* CUdevResourceDesc; + * An opaque descriptor handle. The descriptor encapsulates multiple created and + * configured resources. Created via ::cuDevResourceGenerateDesc + */ +typedef struct CUdevResourceDesc_st *CUdevResourceDesc; + +typedef enum { + CU_GREEN_CTX_DEFAULT_STREAM = 0x1, /**< Required. Creates a default stream to + use inside the green context */ +} CUgreenCtxCreate_flags; + +#define RESOURCE_ABI_VERSION 1 +#define RESOURCE_ABI_EXTERNAL_BYTES 48 + +#define _CONCAT_INNER(x, y) x##y +#define _CONCAT_OUTER(x, y) _CONCAT_INNER(x, y) + +/*! + * \typedef enum CUdevResourceType + * Type of resource + */ +typedef enum { + CU_DEV_RESOURCE_TYPE_INVALID = 0, + CU_DEV_RESOURCE_TYPE_SM = + 1, /**< Streaming multiprocessors related information */ +#ifdef __CUDA_API_VERSION_INTERNAL + CU_DEV_RESOURCE_TYPE_MAX, +#endif +} CUdevResourceType; + +/*! + * \struct CUdevSmResource + * Data for SM-related resources + */ +typedef struct CUdevSmResource_st { + unsigned int smCount; /**< The amount of streaming multiprocessors available + in this resource. This is an output parameter only, + do not write to this field. */ +} CUdevSmResource; + +/*! + * \struct CUdevResource + * A tagged union describing different resources identified by the type field. + * This structure should not be directly modified outside of the API that + * created it. \code struct { CUdevResourceType type; union { CUdevSmResource + * sm; + * }; + * }; + * \endcode + * - If \p type is \p CU_DEV_RESOURCE_TYPE_INVALID, this resource is not valid + * and cannot be further accessed. + * - If \p type is \p CU_DEV_RESOURCE_TYPE_SM, the ::CUdevSmResource structure + * \p sm is filled in. For example, \p sm.smCount will reflect the amount of + * streaming multiprocessors available in this resource. + */ +typedef struct CUdevResource_st { + CUdevResourceType + type; /**< Type of resource, dictates which union field was last set */ + unsigned char _internal_padding[92]; + union { + CUdevSmResource + sm; /**< Resource corresponding to CU_DEV_RESOURCE_TYPE_SM \p. type. */ + unsigned char _oversize[RESOURCE_ABI_EXTERNAL_BYTES]; + }; +} _CONCAT_OUTER(CUdevResource_v, RESOURCE_ABI_VERSION); +typedef _CONCAT_OUTER(CUdevResource_v, RESOURCE_ABI_VERSION) CUdevResource; + +#undef _CONCAT_INNER +#undef _CONCAT_OUTER + +#undef ABI_PER_RESOURCE_EXTERNAL_BYTES +#undef ABI_RESOURCE_VERSION + +/** + * \brief Creates a green context with a specified set of resources. + * + * This API creates a green context with the resources specified in the + * descriptor \p desc and returns it in the handle represented by \p phCtx. This + * API will retain the primary context on device \p dev, which will is released + * when the green context is destroyed. It is advised to have the primary + * context active before calling this API to avoid the heavy cost of triggering + * primary context initialization and deinitialization multiple times. + * + * The API does not set the green context current. In order to set it current, + * you need to explicitly set it current by first converting the green context + * to a CUcontext using ::cuCtxFromGreenCtx and subsequently calling + * ::cuCtxSetCurrent / ::cuCtxPushCurrent. 
It should be noted that a green + * context can be current to only one thread at a time. There is no internal + * synchronization to make API calls accessing the same green context from + * multiple threads work. + * + * Note: The API is not supported on 32-bit platforms. + * + * \param phCtx - Pointer for the output handle to the green context + * \param desc - Descriptor generated via ::cuDevResourceGenerateDesc which + * contains the set of resources to be used \param dev - Device on which to + * create the green context. \param flags - One of the supported green context + * creation flags. \p CU_GREEN_CTX_DEFAULT_STREAM is required. + * + * The supported flags are: + * - \p CU_GREEN_CTX_DEFAULT_STREAM : Creates a default stream to use inside the + * green context. Required. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa + * ::cuGreenCtxDestroy, + * ::cuCtxFromGreenCtx, + * ::cuCtxSetCurrent, + * ::cuCtxPushCurrent, + * ::cuDevResourceGenerateDesc, + * ::cuDevicePrimaryCtxRetain, + * ::cuCtxCreate, + * ::cuCtxCreate_v3 + */ +CUresult CUDAAPI cuGreenCtxCreate(CUgreenCtx *phCtx, CUdevResourceDesc desc, + CUdevice dev, unsigned int flags); + +/** + * \brief Destroys a green context + * + * Destroys the green context, releasing the primary context of the device that + * this green context was created for. Any resources provisioned for this green + * context (that were initially available via the resource descriptor) are + * released as well. \param hCtx - Green context to be destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa + * ::cuGreenCtxCreate, + * ::cuCtxDestroy + */ +CUresult CUDAAPI cuGreenCtxDestroy(CUgreenCtx hCtx); + +/** + * \brief Converts a green context into the primary context + * + * The API converts a green context into the primary context returned in \p + * pContext. It is important to note that the converted context \p pContext is a + * normal primary context but with the resources of the specified green context + * \p hCtx. Once converted, it can then be used to set the context current with + * ::cuCtxSetCurrent or with any of the CUDA APIs that accept a CUcontext + * parameter. + * + * Users are expected to call this API before calling any CUDA APIs that accept + * a CUcontext. Failing to do so will result in the APIs returning + * ::CUDA_ERROR_INVALID_CONTEXT. + * + * \param pContext Returned primary context with green context resources + * \param hCtx Green context to convert + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuGreenCtxCreate + */ +CUresult CUDAAPI cuCtxFromGreenCtx(CUcontext *pContext, CUgreenCtx hCtx); + +/** + * \brief Get device resources + * + * Get the \p type resources available to the \p device. + * This may often be the starting point for further partitioning or configuring + * of resources. + * + * Note: The API is not supported on 32-bit platforms. 
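+ *
+ * An end-to-end sketch of the green-context flow that starts from this call
+ * (illustrative only; \p dev is an existing ::CUdevice, error checking is
+ * omitted, and the group count and minimum SM count are arbitrary examples):
+ * \code
+ *   CUdevResource input, groups[2], remaining;
+ *   unsigned int nbGroups = 2;
+ *   CUdevResourceDesc desc;
+ *   CUgreenCtx gctx;
+ *   cuDeviceGetDevResource(dev, &input, CU_DEV_RESOURCE_TYPE_SM);     // (1) initial resources
+ *   cuDevSmResourceSplitByCount(groups, &nbGroups, &input,
+ *                               &remaining, 0, 16);                   // (2) partition SMs
+ *   cuDevResourceGenerateDesc(&desc, &groups[0], 1);                  // (3) build descriptor
+ *   cuGreenCtxCreate(&gctx, desc, dev, CU_GREEN_CTX_DEFAULT_STREAM);  // (4) create green context
+ * \endcode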
+ * + * \param device - Device to get resource for + * \param resource - Output pointer to a CUdevResource structure + * \param type - Type of resource to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_RESOURCE_TYPE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDevResourceGenerateDesc + */ +CUresult CUDAAPI cuDeviceGetDevResource(CUdevice device, + CUdevResource *resource, + CUdevResourceType type); + +/** + * \brief Get context resources + * + * Get the \p type resources available to the context represented by \p hCtx + * \param hCtx - Context to get resource for + * + * Note: The API is not supported on 32-bit platforms. + * + * \param resource - Output pointer to a CUdevResource structure + * \param type - Type of resource to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_RESOURCE_TYPE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CONTEXT + * + * \sa + * ::cuDevResourceGenerateDesc + */ +CUresult CUDAAPI cuCtxGetDevResource(CUcontext hCtx, CUdevResource *resource, + CUdevResourceType type); + +/** + * \brief Get green context resources + * + * Get the \p type resources available to the green context represented by \p + * hCtx \param hCtx - Green context to get resource for \param resource - Output + * pointer to a CUdevResource structure \param type - Type of resource to + * retrieve + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_RESOURCE_TYPE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuDevResourceGenerateDesc + */ +CUresult CUDAAPI cuGreenCtxGetDevResource(CUgreenCtx hCtx, + CUdevResource *resource, + CUdevResourceType type); + +/** + * \brief Splits \p CU_DEV_RESOURCE_TYPE_SM resources. + * + * Splits \p CU_DEV_RESOURCE_TYPE_SM resources into \p nbGroups, adhering to the + * minimum SM count specified in \p minCount and the usage flags in \p useFlags. + * If \p result is NULL, the API simulates a split and provides the amount of + * groups that would be created in \p nbGroups. Otherwise, \p nbGroups must + * point to the amount of elements in \p result and on return, the API will + * overwrite \p nbGroups with the amount actually created. The groups are + * written to the array in \p result. \p nbGroups can be less than the total + * amount if a smaller number of groups is needed. + * + * This API is used to spatially partition the input resource. The input + * resource needs to come from one of + * ::cuDeviceGetDevResource, ::cuCtxGetDevResource, or + * ::cuGreenCtxGetDevResource. A limitation of the API is that the output + * results cannot be split again without first creating a descriptor and a green + * context with that descriptor. + * + * When creating the groups, the API will take into account the performance and + * functional characteristics of the input resource, and guarantee a split that + * will create a disjoint set of symmetrical partitions. This may lead to less + * groups created than purely dividing the total SM count by the \p minCount due + * to cluster requirements or alignment and granularity requirements for the + * minCount. + * + * The \p remainder set, might not have the same functional or performance + * guarantees as the groups in \p result. 
Its use should be carefully planned + * and future partitions of the \p remainder set are discouraged. + * + * A successful API call must either have: + * - A valid array of \p result pointers of size passed in \p nbGroups, with \p + * Input of type \p CU_DEV_RESOURCE_TYPE_SM. Value of \p minCount must be + * between 0 and the SM count specified in \p input. \p remaining and \p + * useFlags are optional. + * - NULL passed in for \p result, with a valid integer pointer in \p nbGroups + * and \p Input of type \p CU_DEV_RESOURCE_TYPE_SM. Value of \p minCount must be + * between 0 and the SM count specified in \p input. This queries the number of + * groups that would be created by the API. + * + * Note: The API is not supported on 32-bit platforms. + * + * \param result - Output array of \p CUdevResource resources. Can be NULL to + * query the number of groups. \param nbGroups - This is a pointer, specifying + * the number of groups that would be or should be created as described below. + * \param input - Input SM resource to be split. Must be a valid \p + * CU_DEV_RESOURCE_TYPE_SM resource. \param remaining - If the input resource + * cannot be cleanly split among \p nbGroups, the remaining is placed in here. + * Can be omitted (NULL) if the user does not need the remaining set. + * \param useFlags - Flags specifying how these partitions are used or which + * constraints to abide by when splitting the input. \param minCount - Minimum + * number of SMs required + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_RESOURCE_TYPE, + * ::CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION + * + * \sa + * ::cuGreenCtxGetDevResource, + * ::cuCtxGetDevResource, + * ::cuDeviceGetDevResource + */ +CUresult CUDAAPI cuDevSmResourceSplitByCount( + CUdevResource *result, unsigned int *nbGroups, const CUdevResource *input, + CUdevResource *remaining, unsigned int useFlags, unsigned int minCount); + +/** + * \brief Generate a resource descriptor + * + * Generates a resource descriptor with the set of resources specified in \p + * resources. The generated resource descriptor is necessary for the creation of + * green contexts via the ::cuGreenCtxCreate API. The API expects \p nbResources + * == 1, as there is only one type of resource and merging the same types of + * resource is currently not supported. + * + * Note: The API is not supported on 32-bit platforms. + * + * \param phDesc - Output descriptor + * \param resources - Array of resources to be included in the descriptor + * \param nbResources - Number of resources passed in \p resources + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_RESOURCE_TYPE, + * ::CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION + * + * \sa + * ::cuDevSmResourceSplitByCount + */ +CUresult CUDAAPI cuDevResourceGenerateDesc(CUdevResourceDesc *phDesc, + CUdevResource *resources, + unsigned int nbResources); + +/** + * \brief Records an event. + * + * Captures in \phEvent all the activities of the green context of \phCtx + * at the time of this call. \phEvent and \phCtx must be from the same + * CUDA context. Calls such as ::cuEventQuery() or ::cuGreenCtxWaitEvent() will + * then examine or wait for completion of the work that was captured. Uses of + * \p hCtx after this call do not modify \p hEvent. 
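+ *
+ * A short ordering sketch (illustrative; \p hCtxA, \p hCtxB and \p hEvent are
+ * assumed to have been created beforehand within the same CUDA context):
+ * \code
+ *   cuGreenCtxRecordEvent(hCtxA, hEvent);  // capture work submitted to hCtxA so far
+ *   cuGreenCtxWaitEvent(hCtxB, hEvent);    // future work in hCtxB waits on it
+ * \endcode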
+ * + * \note The API will return an error if the specified green context \p hCtx + * has a stream in the capture mode. In such a case, the call will invalidate + * all the conflicting captures. + * + * \param hCtx - Green context to record event for + * \param hEvent - Event to record + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \sa + * ::cuGreenCtxWaitEvent, + * ::cuEventRecord + */ +CUresult CUDAAPI cuGreenCtxRecordEvent(CUgreenCtx hCtx, CUevent hEvent); + +/** + * \brief Make a green context wait on an event + * + * Makes all future work submitted to green context \phCtx wait for all work + * captured in \phEvent. The synchronization will be performed on the device + * and will not block the calling CPU thread. See ::cuGreenCtxRecordEvent() + * for details on what is captured by an event. + * + * \note The API will return an error and invalidate the capture if the + * specified event \p hEvent is part of an ongoing capture sequence. + * + * \param hCtx - Green context to wait + * \param hEvent - Event to wait on (may not be NULL) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \sa + * ::cuGreenCtxRecordEvent, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGreenCtxWaitEvent(CUgreenCtx hCtx, CUevent hEvent); + +/** + * \brief Query the green context associated with a stream + * + * Returns the CUDA green context that the stream is associated with, or NULL if + * the stream is not associated with any green context. + * + * The stream handle \p hStream can refer to any of the following: + *
    + *
+ * - a stream created via any of the CUDA driver APIs such as ::cuStreamCreate.
+ * If during stream creation the context that was active in the calling thread
+ * was obtained with cuCtxFromGreenCtx, that green context is returned in
+ * \p phCtx. Otherwise, \p *phCtx is set to NULL instead.
+ * - a special stream such as the NULL stream or ::CU_STREAM_LEGACY. In that
+ * case, if the context that is active in the calling thread was obtained with
+ * cuCtxFromGreenCtx, that green context is returned. Otherwise, \p *phCtx is
+ * set to NULL instead.
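+ *
+ * For example (illustrative sketch; \p hStream is an existing stream handle):
+ * \code
+ *   CUgreenCtx gctx = NULL;
+ *   cuStreamGetGreenCtx(hStream, &gctx);
+ *   if (gctx == NULL) {
+ *     // the stream is not associated with a green context
+ *   }
+ * \endcode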
+ * Passing an invalid handle will result in undefined behavior. + * + * \param hStream - Handle to the stream to be queried + * \param phCtx - Returned green context associated with the stream + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreate, + * ::cudaStreamCreateWithFlags + */ +CUresult CUDAAPI cuStreamGetGreenCtx(CUstream hStream, CUgreenCtx *phCtx); + +/** @} */ + +/* +** *************** END CUDA_GREEN_CONTEXTS ***************** +*/ + +/** + * CUDA API versioning support + */ +#if defined(__CUDA_API_VERSION_INTERNAL) +#undef cuMemHostRegister +#undef cuGraphicsResourceSetMapFlags +#undef cuLinkCreate +#undef cuLinkAddData +#undef cuLinkAddFile +#undef cuDeviceTotalMem +#undef cuCtxCreate +#undef cuModuleGetGlobal +#undef cuMemGetInfo +#undef cuMemAlloc +#undef cuMemAllocPitch +#undef cuMemFree +#undef cuMemGetAddressRange +#undef cuMemAllocHost +#undef cuMemHostGetDevicePointer +#undef cuMemcpyHtoD +#undef cuMemcpyDtoH +#undef cuMemcpyDtoD +#undef cuMemcpyDtoA +#undef cuMemcpyAtoD +#undef cuMemcpyHtoA +#undef cuMemcpyAtoH +#undef cuMemcpyAtoA +#undef cuMemcpyHtoAAsync +#undef cuMemcpyAtoHAsync +#undef cuMemcpy2D +#undef cuMemcpy2DUnaligned +#undef cuMemcpy3D +#undef cuMemcpyHtoDAsync +#undef cuMemcpyDtoHAsync +#undef cuMemcpyDtoDAsync +#undef cuMemcpy2DAsync +#undef cuMemcpy3DAsync +#undef cuMemsetD8 +#undef cuMemsetD16 +#undef cuMemsetD32 +#undef cuMemsetD2D8 +#undef cuMemsetD2D16 +#undef cuMemsetD2D32 +#undef cuArrayCreate +#undef cuArrayGetDescriptor +#undef cuArray3DCreate +#undef cuArray3DGetDescriptor +#undef cuTexRefSetAddress +#undef cuTexRefSetAddress2D +#undef cuTexRefGetAddress +#undef cuGraphicsResourceGetMappedPointer +#undef cuCtxDestroy +#undef cuCtxPopCurrent +#undef cuCtxPushCurrent +#undef cuStreamDestroy +#undef cuEventDestroy +#undef cuMemcpy +#undef cuMemcpyAsync +#undef cuMemcpyPeer +#undef cuMemcpyPeerAsync +#undef cuMemcpy3DPeer +#undef cuMemcpy3DPeerAsync +#undef cuMemsetD8Async +#undef cuMemsetD16Async +#undef cuMemsetD32Async +#undef cuMemsetD2D8Async +#undef cuMemsetD2D16Async +#undef cuMemsetD2D32Async +#undef cuStreamGetPriority +#undef cuStreamGetId +#undef cuStreamGetFlags +#undef cuStreamGetCtx +#undef cuStreamWaitEvent +#undef cuStreamAddCallback +#undef cuStreamAttachMemAsync +#undef cuStreamQuery +#undef cuStreamSynchronize +#undef cuEventRecord +#undef cuEventRecordWithFlags +#undef cuLaunchKernel +#undef cuLaunchKernelEx +#undef cuLaunchHostFunc +#undef cuGraphicsMapResources +#undef cuGraphicsUnmapResources +#undef cuStreamWriteValue32 +#undef cuStreamWaitValue32 +#undef cuStreamWriteValue64 +#undef cuStreamWaitValue64 +#undef cuStreamBatchMemOp +#undef cuStreamWriteValue32_v2 +#undef cuStreamWaitValue32_v2 +#undef cuStreamWriteValue64_v2 +#undef cuStreamWaitValue64_v2 +#undef cuStreamBatchMemOp_v2 +#undef cuMemPrefetchAsync +#undef cuMemPrefetchAsync_v2 +#undef cuLaunchCooperativeKernel +#undef cuSignalExternalSemaphoresAsync +#undef cuWaitExternalSemaphoresAsync +#undef cuStreamBeginCapture +#undef cuStreamBeginCaptureToGraph +#undef cuStreamEndCapture +#undef cuStreamIsCapturing +#undef cuStreamGetCaptureInfo +#undef cuStreamGetCaptureInfo_v2 +#undef cuStreamGetCaptureInfo_v3 
+#undef cuGraphInstantiateWithParams +#undef cuGraphExecUpdate +#undef cuGraphUpload +#undef cuGraphLaunch +#undef cuDevicePrimaryCtxRelease +#undef cuDevicePrimaryCtxReset +#undef cuDevicePrimaryCtxSetFlags +#undef cuIpcOpenMemHandle +#undef cuStreamCopyAttributes +#undef cuStreamSetAttribute +#undef cuStreamGetAttribute +#undef cuGraphInstantiate +#undef cuGraphAddKernelNode +#undef cuGraphKernelNodeGetParams +#undef cuGraphKernelNodeSetParams +#undef cuGraphExecKernelNodeSetParams +#undef cuMemMapArrayAsync +#undef cuMemFreeAsync +#undef cuMemAllocAsync +#undef cuMemAllocFromPoolAsync +#undef cuStreamUpdateCaptureDependencies +#undef cuStreamUpdateCaptureDependencies_v2 +#undef cuGetProcAddress + +CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, + unsigned int Flags); +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags); +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut); +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void **optionValues); +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues); +CUresult CUDAAPI cuTexRefSetAddress2D_v2(CUtexref hTexRef, + const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch); + +typedef unsigned int CUdeviceptr_v1; + +typedef struct CUDA_MEMCPY2D_v1_st { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + CUmemorytype + dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + + unsigned int WidthInBytes; /**< Width of 2D memory copy in bytes */ + unsigned int Height; /**< Height of 2D memory copy */ +} CUDA_MEMCPY2D_v1; + +typedef struct CUDA_MEMCPY3D_v1_st { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + unsigned int srcZ; /**< Source Z */ + unsigned int srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + void *reserved0; /**< Must be NULL */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + unsigned int srcHeight; /**< Source height (ignored when src is array; may be + 0 if Depth==1) */ + + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + unsigned int dstZ; /**< Destination Z */ + unsigned int dstLOD; /**< Destination LOD */ + CUmemorytype + dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< 
Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + void *reserved1; /**< Must be NULL */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + unsigned int dstHeight; /**< Destination height (ignored when dst is array; + may be 0 if Depth==1) */ + + unsigned int WidthInBytes; /**< Width of 3D memory copy in bytes */ + unsigned int Height; /**< Height of 3D memory copy */ + unsigned int Depth; /**< Depth of 3D memory copy */ +} CUDA_MEMCPY3D_v1; + +typedef struct CUDA_ARRAY_DESCRIPTOR_v1_st { + unsigned int Width; /**< Width of array */ + unsigned int Height; /**< Height of array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ +} CUDA_ARRAY_DESCRIPTOR_v1; + +typedef struct CUDA_ARRAY3D_DESCRIPTOR_v1_st { + unsigned int Width; /**< Width of 3D array */ + unsigned int Height; /**< Height of 3D array */ + unsigned int Depth; /**< Depth of 3D array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + unsigned int Flags; /**< Flags */ +} CUDA_ARRAY3D_DESCRIPTOR_v1; + +CUresult CUDAAPI cuDeviceTotalMem(unsigned int *bytes, CUdevice dev); +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr_v1 *dptr, unsigned int *bytes, + CUmodule hmod, const char *name); +CUresult CUDAAPI cuMemGetInfo(unsigned int *free, unsigned int *total); +CUresult CUDAAPI cuMemAlloc(CUdeviceptr_v1 *dptr, unsigned int bytesize); +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr_v1 *dptr, unsigned int *pPitch, + unsigned int WidthInBytes, unsigned int Height, + unsigned int ElementSizeBytes); +CUresult CUDAAPI cuMemFree(CUdeviceptr_v1 dptr); +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr_v1 *pbase, + unsigned int *psize, CUdeviceptr_v1 dptr); +CUresult CUDAAPI cuMemAllocHost(void **pp, unsigned int bytesize); +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr_v1 *pdptr, void *p, + unsigned int Flags); +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr_v1 dstDevice, const void *srcHost, + unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr_v1 srcDevice, + unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr_v1 dstDevice, + CUdeviceptr_v1 srcDevice, unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, unsigned int dstOffset, + CUdeviceptr_v1 srcDevice, unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr_v1 dstDevice, CUarray srcArray, + unsigned int srcOffset, unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, unsigned int dstOffset, + const void *srcHost, unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, + unsigned int srcOffset, unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, unsigned int dstOffset, + CUarray srcArray, unsigned int srcOffset, + unsigned int ByteCount); +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, unsigned int dstOffset, + const void *srcHost, unsigned int ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, + unsigned int srcOffset, + unsigned int ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D_v1 *pCopy); +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D_v1 *pCopy); +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D_v1 
*pCopy); +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr_v1 dstDevice, + const void *srcHost, unsigned int ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr_v1 srcDevice, + unsigned int ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr_v1 dstDevice, + CUdeviceptr_v1 srcDevice, + unsigned int ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D_v1 *pCopy, + CUstream hStream); +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D_v1 *pCopy, + CUstream hStream); +CUresult CUDAAPI cuMemsetD8(CUdeviceptr_v1 dstDevice, unsigned char uc, + unsigned int N); +CUresult CUDAAPI cuMemsetD16(CUdeviceptr_v1 dstDevice, unsigned short us, + unsigned int N); +CUresult CUDAAPI cuMemsetD32(CUdeviceptr_v1 dstDevice, unsigned int ui, + unsigned int N); +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, + unsigned char uc, unsigned int Width, + unsigned int Height); +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, + unsigned short us, unsigned int Width, + unsigned int Height); +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, + unsigned int ui, unsigned int Width, + unsigned int Height); +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR_v1 *pAllocateArray); +CUresult CUDAAPI cuArrayGetDescriptor( + CUDA_ARRAY_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); +CUresult CUDAAPI cuArray3DCreate( + CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR_v1 *pAllocateArray); +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); +CUresult CUDAAPI cuTexRefSetAddress(unsigned int *ByteOffset, CUtexref hTexRef, + CUdeviceptr_v1 dptr, unsigned int bytes); +CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, + const CUDA_ARRAY_DESCRIPTOR_v1 *desc, + CUdeviceptr_v1 dptr, unsigned int Pitch); +CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr_v1 *pdptr, CUtexref hTexRef); +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr_v1 *pDevPtr, unsigned int *pSize, CUgraphicsResource resource); + +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); +CUresult CUDAAPI cuStreamDestroy(CUstream hStream); +CUresult CUDAAPI cuEventDestroy(CUevent hEvent); +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); + +CUresult CUDAAPI cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount); +CUresult CUDAAPI cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount); +CUresult CUDAAPI cuMemcpyDtoD_v2(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount); +CUresult CUDAAPI cuMemcpyDtoA_v2(CUarray dstArray, size_t dstOffset, + CUdeviceptr srcDevice, size_t ByteCount); +CUresult CUDAAPI cuMemcpyAtoD_v2(CUdeviceptr dstDevice, CUarray srcArray, + size_t srcOffset, size_t ByteCount); +CUresult CUDAAPI cuMemcpyHtoA_v2(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount); +CUresult CUDAAPI cuMemcpyAtoH_v2(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount); +CUresult CUDAAPI cuMemcpyAtoA_v2(CUarray dstArray, size_t dstOffset, + CUarray srcArray, size_t srcOffset, + size_t ByteCount); +CUresult CUDAAPI cuMemcpyHtoAAsync_v2(CUarray dstArray, 
size_t dstOffset, + const void *srcHost, size_t ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpyAtoHAsync_v2(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpy2D_v2(const CUDA_MEMCPY2D *pCopy); +CUresult CUDAAPI cuMemcpy2DUnaligned_v2(const CUDA_MEMCPY2D *pCopy); +CUresult CUDAAPI cuMemcpy3D_v2(const CUDA_MEMCPY3D *pCopy); +CUresult CUDAAPI cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, + const void *srcHost, size_t ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpyDtoDAsync_v2(CUdeviceptr dstDevice, + CUdeviceptr srcDevice, size_t ByteCount, + CUstream hStream); +CUresult CUDAAPI cuMemcpy2DAsync_v2(const CUDA_MEMCPY2D *pCopy, + CUstream hStream); +CUresult CUDAAPI cuMemcpy3DAsync_v2(const CUDA_MEMCPY3D *pCopy, + CUstream hStream); +CUresult CUDAAPI cuMemsetD8_v2(CUdeviceptr dstDevice, unsigned char uc, + size_t N); +CUresult CUDAAPI cuMemsetD16_v2(CUdeviceptr dstDevice, unsigned short us, + size_t N); +CUresult CUDAAPI cuMemsetD32_v2(CUdeviceptr dstDevice, unsigned int ui, + size_t N); +CUresult CUDAAPI cuMemsetD2D8_v2(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height); +CUresult CUDAAPI cuMemsetD2D16_v2(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height); +CUresult CUDAAPI cuMemsetD2D32_v2(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height); +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount); +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, + size_t ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount); +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount, CUstream hStream); +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy); +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream); + +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream); +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream); +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream); +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream); +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream); +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream); + +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); +CUresult CUDAAPI cuStreamGetId(CUstream hStream, unsigned long long *streamId); +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags); +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags); +CUresult CUDAAPI 
cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags); +CUresult CUDAAPI cuStreamQuery(CUstream hStream); +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); +CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, + unsigned int flags); +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra); +CUresult CUDAAPI cuLaunchKernelEx(const CUlaunchConfig *config, CUfunction f, + void **kernelParams, void **extra); +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, + void *userData); +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream); +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream); +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags); + +CUresult CUDAAPI cuStreamWriteValue32_ptsz(CUstream stream, CUdeviceptr addr, + cuuint32_t value, + unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue32_ptsz(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWriteValue64_ptsz(CUstream stream, CUdeviceptr addr, + cuuint64_t value, + unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue64_ptsz(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); +CUresult CUDAAPI cuStreamBatchMemOp_ptsz(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags); + +CUresult CUDAAPI cuStreamWriteValue32_v2(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue32_v2(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWriteValue64_v2(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); +CUresult CUDAAPI cuStreamWaitValue64_v2(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags); +CUresult CUDAAPI cuStreamBatchMemOp_v2(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags); +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream); +CUresult CUDAAPI cuMemPrefetchAsync_v2(CUdeviceptr devPtr, size_t count, + CUmemLocation location, + unsigned int flags, CUstream hStream); +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams); +CUresult CUDAAPI cuSignalExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const 
CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream); +CUresult CUDAAPI cuWaitExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream); +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); +CUresult CUDAAPI cuStreamBeginCapture_ptsz(CUstream hStream); +CUresult CUDAAPI cuStreamBeginCapture_v2(CUstream hStream, + CUstreamCaptureMode mode); +CUresult CUDAAPI cuStreamBeginCaptureToGraph( + CUstream hStream, CUgraph hGraph, const CUgraphNode *dependencies, + const CUgraphEdgeData *dependencyData, size_t numDependencies, + CUstreamCaptureMode mode); +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, + CUstreamCaptureStatus *captureStatus); +CUresult CUDAAPI cuStreamGetCaptureInfo( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out); +CUresult CUDAAPI cuStreamGetCaptureInfo_ptsz( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out); +CUresult CUDAAPI cuStreamGetCaptureInfo_v2( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, + const CUgraphNode **dependencies_out, size_t *numDependencies_out); +CUresult CUDAAPI cuStreamGetCaptureInfo_v3( + CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, + const CUgraphNode **dependencies_out, const CUgraphEdgeData **edgeData_out, + size_t *numDependencies_out); +CUresult CUDAAPI cuGraphAddKernelNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); +CUresult CUDAAPI cuGraphKernelNodeGetParams( + CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); +CUresult CUDAAPI cuGraphKernelNodeSetParams( + CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); +CUresult CUDAAPI +cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); +CUresult CUDAAPI +cuGraphInstantiateWithParams(CUgraphExec *phGraphExec, CUgraph hGraph, + CUDA_GRAPH_INSTANTIATE_PARAMS *instantiateParams); +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, + CUgraphNode *hErrorNode_out, + CUgraphExecUpdateResult *updateResult_out); +CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraph, CUstream hStream); +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraph, CUstream hStream); +CUresult CUDAAPI cuStreamCopyAttributes(CUstream dstStream, CUstream srcStream); +CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, + CUstreamAttrValue *value); +CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, + const CUstreamAttrValue *param); + +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags); +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, char *logBuffer, + size_t bufferSize); +CUresult CUDAAPI cuGraphInstantiate_v2(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, + char *logBuffer, size_t bufferSize); + +CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, + unsigned int count, CUstream hStream); + +CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); +CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t 
bytesize, + CUstream hStream); +CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, + CUmemoryPool pool, CUstream hStream); + +CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, + CUgraphNode *dependencies, + size_t numDependencies, + unsigned int flags); +CUresult CUDAAPI cuStreamUpdateCaptureDependencies_v2( + CUstream hStream, CUgraphNode *dependencies, + const CUgraphEdgeData *dependencyData, size_t numDependencies, + unsigned int flags); +CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, + int cudaVersion, cuuint64_t flags); + +#elif defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) +static inline CUresult +cuGetProcAddress_v2_ptsz(const char *symbol, void **funcPtr, int driverVersion, + cuuint64_t flags, + CUdriverProcAddressQueryResult *symbolStatus) { + const int procAddressMask = (CU_GET_PROC_ADDRESS_LEGACY_STREAM | + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM); + if ((flags & procAddressMask) == 0) { + flags |= CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM; + } + return cuGetProcAddress_v2(symbol, funcPtr, driverVersion, flags, + symbolStatus); +} +#define cuGetProcAddress_v2 cuGetProcAddress_v2_ptsz +#endif + +#ifdef __cplusplus +} +#endif + +#if defined(__GNUC__) +#if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) +#pragma GCC visibility pop +#endif +#endif + +#undef __CUDA_DEPRECATED + +#endif /* __cuda_cuda_h__ */ diff --git a/tilelang/original/src/target/intrin_rule_cuda.cc b/tilelang/original/src/target/intrin_rule_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..1aacd72048ee721e8c32ee29fee742f25a070349 --- /dev/null +++ b/tilelang/original/src/target/intrin_rule_cuda.cc @@ -0,0 +1,139 @@ +/*! + * \file intrin_rule_cuda.cc + * \brief CUDA intrinsic rules. + */ +#include +#include + +#include "../support/ffi_aliases.h" +#include "target/intrin_rule.h" + +namespace tvm { +namespace codegen { +namespace intrin { +// Add float suffix to the intrinsics, CUDA fast math. +using tir::FLowerIntrinsic; + +struct CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + case 32: + return name + 'f'; + case 16: { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } + default: + return ""; + } + } else if (t.is_bfloat16()) { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } else if (t.is_int() || t.is_uint()) { + switch (t.bits()) { + case 32: + return "__" + name; + case 64: + return "__" + name + "ll"; + default: + return ""; + } + } + return ""; + } +}; + +struct CUDAFastMath : public CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float() && t.bits() == 32) { + return "__" + name + 'f'; + } else { + return CUDAMath::operator()(t, name); + } + return ""; + } +}; + +struct CUDAFastMathTan : public CUDAMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + // `__tanf` seems to produce some values too deviant from numpy tan + // version. So, let's use just `tanf` instead. 
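// (Context, assuming standard CUDA device math: __tanf is the fast-math
// approximation with much looser error bounds, while tanf follows the regular
// math-library accuracy, which is why the accurate form is preferred here.)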
+ case 32: + return name + 'f'; + case 16: + return 'h' + name; + default: + return ""; + } + } + return ""; + } +}; + +struct CUDAPopcount { + std::string operator()(DataType t, std::string name) const { + if (t.is_uint()) { + switch (t.bits()) { + case 32: + return "__popc"; + case 64: + return "__popcll"; + default: + return ""; + } + } + return ""; + } +}; + +struct CUDAWarpIntrinsic { + const Op operator()(DataType t, const Op &orig_op) const { + if (orig_op.same_as(builtin::tvm_warp_shuffle())) { + return Op::Get("tir.cuda.__shfl_sync"); + } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { + return Op::Get("tir.cuda.__shfl_up_sync"); + } else { + ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + return Op::Get("tir.cuda.__shfl_down_sync"); + } + } +}; + +static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) { + const CallNode *call = e.as(); + return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); +} + +template static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) { + const CallNode *call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + Array cuda_args{ + {call->args[0], call->args[1], call->args[2], call->args[3]}}; + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); +} + +TVM_REGISTER_OP("tir.rsqrt") + .set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); + +} // namespace intrin +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/intrin_rule_hip.cc b/tilelang/original/src/target/intrin_rule_hip.cc new file mode 100644 index 0000000000000000000000000000000000000000..ef4d6b6001d009807b8103e175c40666a4829a06 --- /dev/null +++ b/tilelang/original/src/target/intrin_rule_hip.cc @@ -0,0 +1,299 @@ +/*! + * \file intrin_rule_hip.cc + * \brief HIP intrinsic rules. + */ +#include +#include + +#include "../support/ffi_aliases.h" +#include "target/intrin_rule.h" + +namespace tvm { +namespace codegen { +namespace intrin { +// Add float suffix to the intrinsics, HIP fast math. 
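// These HIP rules mirror intrin_rule_cuda.cc above; the main difference is
// that the warp-shuffle builtins are registered under the "tir.hip." prefix.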
+using tir::FLowerIntrinsic; + +struct HIPMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + case 32: + return name + 'f'; + case 16: { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } + default: + return ""; + } + } else if (t.is_bfloat16()) { + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } + } else if (t.is_int() || t.is_uint()) { + switch (t.bits()) { + case 32: + return "__" + name; + case 64: + return "__" + name + "ll"; + default: + return ""; + } + } + return ""; + } +}; + +struct HIPFastMath : public HIPMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float() && t.bits() == 32) { + return "__" + name + 'f'; + } else { + return HIPMath::operator()(t, name); + } + return ""; + } +}; + +struct HIPFastMathTan : public HIPMath { + std::string operator()(DataType t, std::string name) const { + if (t.is_float()) { + switch (t.bits()) { + case 64: + return name; + case 32: + return name + 'f'; + case 16: + return std::string("h") + name; + default: + return ""; + } + } + return ""; + } +}; + +struct HIPPopcount { + std::string operator()(DataType t, std::string name) const { + if (t.is_uint()) { + switch (t.bits()) { + case 32: + return "__popc"; + case 64: + return "__popcll"; + default: + return ""; + } + } + return ""; + } +}; + +struct HIPWarpIntrinsic { + const Op operator()(DataType t, const Op &orig_op) const { + if (orig_op.same_as(builtin::tvm_warp_shuffle())) { + return Op::Get("tir.hip.__shfl_sync"); + } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { + return Op::Get("tir.hip.__shfl_up_sync"); + } else { + ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + return Op::Get("tir.hip.__shfl_down_sync"); + } + } +}; + +static PrimExpr DispatchHIPWarpActiveMask(const PrimExpr &e) { + const CallNode *call = e.as(); + ICHECK(call != nullptr); + return Call(call->dtype, Op::Get("tir.hip.__activemask"), {}); +} + +template static PrimExpr DispatchHIPShuffle(const PrimExpr &e) { + // NOLINTBEGIN(clang-analyzer-cplusplus.InnerPointer) + const CallNode *call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + Array hip_args{ + {call->args[0], call->args[1], call->args[2], call->args[3]}}; + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), hip_args); + // NOLINTEND(clang-analyzer-cplusplus.InnerPointer) +} + +TVM_REGISTER_OP("tir.clz").set_attr( + "hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.floor") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.ceil") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.trunc") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.fabs") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.round") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.nearbyint") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp2") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp10") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + 
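For a concrete sense of the name mapping these registrations rely on, here is a minimal standalone sketch (a hypothetical helper, not part of the diff) that reproduces the float-suffix rule implemented by HIPMath above:

```cpp
#include <iostream>
#include <string>

// Minimal sketch of the float-suffix rule used by CUDAMath/HIPMath:
// f64 keeps the plain libm name, f32 appends 'f', and f16/bf16 map to the
// half-precision "h"-prefixed intrinsics (fabs and round are special-cased).
std::string SuffixedName(int float_bits, const std::string &name) {
  switch (float_bits) {
  case 64:
    return name;                 // e.g. "exp"   -> "exp"
  case 32:
    return name + 'f';           // e.g. "exp"   -> "expf"
  case 16:
    if (name == "fabs") return "__habs";
    if (name == "round") return "hrint";
    return "h" + name;           // e.g. "exp"   -> "hexp"
  default:
    return "";                   // unsupported width
  }
}

int main() {
  std::cout << SuffixedName(32, "exp") << "\n";   // expf
  std::cout << SuffixedName(16, "round") << "\n"; // hrint
  std::cout << SuffixedName(64, "sqrt") << "\n";  // sqrt
  return 0;
}
```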
+TVM_REGISTER_OP("tir.erf").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.log").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.log2") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.log10") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.tan").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.cos").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.cosh") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.sin").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.sinh") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.atan") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.tanh") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.sqrt") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.pow").set_attr( + "hip.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.popcount") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.tvm_warp_shuffle") + .set_attr("hip.FLowerIntrinsic", + DispatchHIPShuffle); + +TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") + .set_attr("hip.FLowerIntrinsic", + DispatchHIPShuffle); + +TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") + .set_attr("hip.FLowerIntrinsic", + DispatchHIPShuffle); + +TVM_REGISTER_OP("tir.tvm_warp_activemask") + .set_attr("hip.FLowerIntrinsic", + DispatchHIPWarpActiveMask); + +TVM_REGISTER_OP("tir.fmod") + .set_attr("hip.FLowerIntrinsic", + DispatchPureExtern); + +// Register low-level builtin ops. 
+TVM_REGISTER_OP("tir.hip.__shfl") + .set_num_inputs(3) + .add_argument("var", "Expr", "Value to shuffle") + .add_argument("lane", "Expr", "Source lane") + .add_argument("width", "Expr", "Warp width") + .set_attr("TGlobalSymbol", "__shfl") + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.hip.__shfl_sync") + .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") + .add_argument("width", "Expr", + "The warp thread width, must be a power of 2.") + .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)) + .set_attr("hip.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.hip.__shfl_up_sync") + .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") + .add_argument("width", "Expr", + "The warp thread width, must be a power of 2.") + .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)) + .set_attr("hip.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.hip.__shfl_down_sync") + .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", + "The source lane id offset to be subtracted.") + .add_argument("width", "Expr", + "The warp thread width, must be a power of 2.") + .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)) + .set_attr("hip.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.hip.__activemask") + .set_num_inputs(0) + .set_attr("TGlobalSymbol", "__activemask") + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)) + .set_attr("hip.need_warp_shuffle", true); + +} // namespace intrin +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/ptx.cc b/tilelang/original/src/target/ptx.cc new file mode 100644 index 0000000000000000000000000000000000000000..53f83ded93d3dd6f66f133bb53ded63ae89e0fd4 --- /dev/null +++ b/tilelang/original/src/target/ptx.cc @@ -0,0 +1,1548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx.cc + */ + +#include "ptx.h" + +#include +#include +#include +#include +#include + +namespace tvm::tl { +namespace codegen { + +// PTX related data structures and functions. 
+namespace ptx { + +static const char *enum_to_str[] = { + "kInt4", "kUInt4", "kInt8", "kUInt8", "kInt16", + "kUInt16", "kInt32", "kUInt32", "kInt64", "kUInt64", + "kFloat8_e4m3", "kFloat8_e5m2", "kFloat16", "kBFloat16", "kFloat16x2", + "kFloat32", "kTensorFloat32", "kFloat64", "kBit1", "kBit8", + "kBit16", "kBit32", "kBit64"}; + +static const char *dtype_str[] = { + ".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32", + ".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32", + ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, + 64, 64, 8, 8, 16, 16, 32, 32, + 32, 64, 1, 8, 16, 32, 64}; + +/*! + * \brief Create PTX data type from string. + */ +DataType DTypeFromString(const std::string str) { + if (str == "int4" || str == ".s4") { + return DataType::kInt4; + } else if (str == "uint4" || str == ".u4") { + return DataType::kUInt4; + } else if (str == "int8" || str == ".s8") { + return DataType::kInt8; + } else if (str == "uint8" || str == ".u8") { + return DataType::kUInt8; + } else if (str == "int16" || str == ".s16") { + return DataType::kInt16; + } else if (str == "uint16" || str == ".u16") { + return DataType::kUInt16; + } else if (str == "int32" || str == ".s32") { + return DataType::kInt32; + } else if (str == "uint32" || str == ".u32") { + return DataType::kUInt32; + } else if (str == "int64" || str == ".s64") { + return DataType::kInt64; + } else if (str == "uint64" || str == ".u64") { + return DataType::kUInt64; + } else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") { + return DataType::kFloat8_e4m3; + } else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") { + return DataType::kFloat8_e5m2; + } else if (str == "float16" || str == "fp16" || str == ".f16") { + return DataType::kFloat16; + } else if (str == "bfloat16" || str == "bf16") { + return DataType::kBFloat16; + } else if (str == ".f16x2") { + return DataType::kFloat16x2; + } else if (str == "float32" || str == "fp32" || str == ".f32") { + return DataType::kFloat32; + } else if (str == "tf32") { + return DataType::kTensorFloat32; + } else if (str == "float64" || str == "fp64" || str == ".f64") { + return DataType::kFloat64; + } else if (str == "int1" || str == ".b1") { + return DataType::kBit1; + } else if (str == ".b8") { + return DataType::kBit8; + } else if (str == ".b16") { + return DataType::kBit16; + } else if (str == ".b32") { + return DataType::kBit32; + } else if (str == ".b64") { + return DataType::kBit64; + } else { + LOG(FATAL) << "Unrecognized PTX data type " << str; + } +} + +std::string DTypeEnumToString(const ptx::DataType &dtype) { + return "tl::DataType::" + std::string(enum_to_str[static_cast(dtype)]); +} + +std::string DTypeEnumToString(const std::string &dtype) { + return "tl::DataType::" + + std::string(enum_to_str[static_cast(DTypeFromString(dtype))]); +} + +/*! + * \brief Get the string representation of given PTX data type. + */ +inline std::string DTypeToString(DataType dtype) { + return dtype_str[static_cast(dtype)]; +} + +/*! + * \brief Get the number of bits of given PTX data type. 
+ */ +inline uint32_t DTypeBits(DataType dtype) { + return num_bits[static_cast(dtype)]; +} + +inline bool DTypeIsInteger(DataType dtype) { + return dtype == DataType::kInt4 || dtype == DataType::kInt8 || + dtype == DataType::kInt16 || dtype == DataType::kInt32 || + dtype == DataType::kInt64 || dtype == DataType::kUInt4 || + dtype == DataType::kUInt8 || dtype == DataType::kUInt16 || + dtype == DataType::kUInt32 || dtype == DataType::kUInt64; +} + +/*! + * \brief Extract the value m, n, k from string m*n*k* + */ +std::tuple ParseMMAShape(const std::string &str) { + size_t pos_m = str.find('m'), pos_n = str.find('n'), pos_k = str.find('k'); + CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + << "Cannot parse MMA shape " << str; + int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), + k = std::stoi(str.substr(pos_k + 1)); + return std::make_tuple(m, n, k); +} + +/*! + * \brief Layout Type + */ +enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; + +/*! + * \brief Parse layout type + */ +LayoutType LayoutTypeFromString(const std::string &str) { + if (str == "row") { + return LayoutType::kRowMajor; + } else if (str == "col") { + return LayoutType::kColumnMajor; + } else { + LOG(FATAL) << "Unrecognized layout type " << str; + } +} + +/*! + * \brief Parse layout type from bool. + */ +LayoutType LayoutTypeFromBool(const bool &layout) { + if (layout) { + return LayoutType::kRowMajor; + } else { + return LayoutType::kColumnMajor; + } +} + +static const char *layout_type_str[] = {"row", "col"}; + +/*! + * \brief Convert layout type to string. + */ +inline std::string LayoutTypeToString(LayoutType layout) { + return layout_type_str[static_cast(layout)]; +} + +/*! + * \brief MMA Configurations, used to determine validity. + */ +struct MMAConfig { + explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, + bool sparse) + : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), + sparse(sparse) {} + int m, n, k; + DataType dtype_mul; + bool use_bit_op; + bool sparse; + inline bool operator==(const MMAConfig &other) { + return m == other.m && n == other.n && k == other.k && + dtype_mul == other.dtype_mul && use_bit_op == other.use_bit_op && + sparse == other.sparse; + } +}; + +/*! 
+ * \brief Valid MMA configurations + * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape + */ +const MMAConfig valid_mma_configs[] = { + MMAConfig(8, 8, 4, DataType::kFloat64, false, false), + MMAConfig(8, 8, 4, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 4, DataType::kFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kFloat32, false, false), + MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false), + MMAConfig(8, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 32, DataType::kInt8, false, false), + MMAConfig(8, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 32, DataType::kUInt8, false, false), + MMAConfig(8, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 64, DataType::kInt4, false, false), + MMAConfig(8, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 64, DataType::kUInt4, false, false), + MMAConfig(8, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 256, DataType::kBit1, true, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kFloat16, false, true), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 32, DataType::kInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt8, false, true), + MMAConfig(16, 8, 32, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt4, false, true), + MMAConfig(16, 8, 128, DataType::kInt4, false, true), + MMAConfig(16, 8, 64, DataType::kUInt4, false, true), + MMAConfig(16, 8, 128, DataType::kUInt4, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e4m3, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e4m3, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e5m2, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true), +}; + +struct WGMMAConfig { + explicit WGMMAConfig(int m, int n, int k, DataType dtype_a, DataType dtype_b, + DataType dtype_c, bool sparse) + : m(m), n(n), k(k), dtype_a(dtype_a), dtype_b(dtype_b), dtype_c(dtype_c), + sparse(sparse) {} + int m, n, k; + DataType dtype_a, dtype_b, dtype_c; + bool sparse; + inline bool operator==(const WGMMAConfig &other) { + return m == other.m && n == other.n && k == other.k && + dtype_a == other.dtype_a && dtype_b == other.dtype_b && + dtype_c == other.dtype_c && sparse == other.sparse; + } +}; + +const WGMMAConfig valid_wgmma_configs[] = { + // Dense FP16 configurations + WGMMAConfig(64, 8, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + 
WGMMAConfig(64, 64, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + + // Dense FP16 to FP32 accumulation + WGMMAConfig(64, 8, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + + // Dense BFloat16 configurations + WGMMAConfig(64, 8, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + + // Dense TF32 configurations + WGMMAConfig(64, 8, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 24, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 40, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + + // Dense INT8 configurations + WGMMAConfig(64, 8, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 16, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 32, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 64, 32, 
DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 96, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 128, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 192, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 256, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + + // Dense UINT8 configurations + WGMMAConfig(64, 8, 32, DataType::kUInt8, DataType::kUInt8, DataType::kInt32, + false), + WGMMAConfig(64, 16, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 32, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 64, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 96, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 128, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 192, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 256, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + + // Dense INT4 configurations + WGMMAConfig(64, 8, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 16, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 32, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 64, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 96, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 128, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 192, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 256, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + + // Dense UINT4 configurations + WGMMAConfig(64, 8, 64, DataType::kUInt4, DataType::kUInt4, DataType::kInt32, + false), + WGMMAConfig(64, 16, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 32, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 64, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 96, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 128, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 192, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 256, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + + // Dense FP8 E4M3 configurations + WGMMAConfig(64, 8, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + 
WGMMAConfig(64, 8, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + + // Dense FP8 E5M2 configurations + WGMMAConfig(64, 8, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 8, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + + // Sparse FP16 configurations (k doubled for sparsity) + WGMMAConfig(64, 8, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + + // Sparse FP16 to FP32 accumulation + WGMMAConfig(64, 8, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 32, DataType::kFloat16, 
DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + + // Sparse BFloat16 configurations + WGMMAConfig(64, 8, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + + // Sparse TF32 configurations + WGMMAConfig(64, 8, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + + // Sparse INT8 configurations + WGMMAConfig(64, 8, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 16, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 32, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 64, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 96, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 128, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 192, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 256, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + + // Sparse UINT8 configurations + WGMMAConfig(64, 8, 64, DataType::kUInt8, DataType::kUInt8, DataType::kInt32, + true), + WGMMAConfig(64, 16, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 32, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 64, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 96, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + 
WGMMAConfig(64, 128, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 192, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 256, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + + // Sparse INT4 configurations + WGMMAConfig(64, 8, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 16, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 32, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 64, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 96, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 128, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + WGMMAConfig(64, 192, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + WGMMAConfig(64, 256, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + + // Sparse UINT4 configurations + WGMMAConfig(64, 8, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 16, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 32, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 64, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 96, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 128, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 192, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 256, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + + // Sparse FP8 E4M3 configurations + WGMMAConfig(64, 8, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 8, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + + // Sparse FP8 E5M2 configurations + WGMMAConfig(64, 8, 64, DataType::kFloat8_e5m2, 
DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 8, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true)}; + +/*! + * \brief Check whether the multiplicand data type and accumulator data type is + * valid for MMA computation. \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. 
+ * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, + DataType dtype_c) { + std::string ab_not_match_err_str = "The multiplicands' data type " + + DTypeToString(dtype_a) + + DTypeToString(dtype_b) + " do not match."; + // check a and b + switch (dtype_a) { + case DataType::kBit1: + case DataType::kFloat16: + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + case DataType::kFloat64: + CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + break; + case DataType::kInt4: + case DataType::kUInt4: + CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) + << ab_not_match_err_str; + break; + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) + << ab_not_match_err_str; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_b == DataType::kFloat8_e4m3 || + dtype_b == DataType::kFloat8_e5m2) + << ab_not_match_err_str; + break; + default: + CHECK(false) << "Invalid multiplicand data types: " + << DTypeToString(dtype_a) << DTypeToString(dtype_b); + } + // check a,b and c + switch (dtype_a) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_c == DataType::kInt32) + << "For multiplicand data type " << DTypeToString(dtype_a) + << DTypeToString(dtype_b) << ", accumulator data type should be s32."; + break; + case DataType::kFloat16: + CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) + << "For multiplicand data type f16, accumulator data type should be " + "f16/f32."; + break; + case DataType::kBFloat16: + case DataType::kFloat32: + case DataType::kTensorFloat32: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can " + "only be f32."; + break; + case DataType::kFloat64: + CHECK(dtype_c == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be " + "f64."; + break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type e4m3/e5m2, accumulator data type can " + "only be f32."; + break; + default: + CHECK(false) << "Invalid multiplicand/accumulator data types: " + << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << DTypeToString(dtype_c) << "."; + } +} + +/*! + * \brief Check whether the given configuration is valid for MMA computation. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param layout_a The layout of multiplicand A (row/col). + * \param layout_b The layout of multiplicand B (row/col). + * \param dtype_a The data type of multiplicand A. + * \param dtype_b The data type of multiplicand B. + * \param dtype_c The data type of accumulator C. + * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" + * or ""(if it's not 1-bit MMA). \param sparse Whether it's Sparse MMA or not. + * \param saturate Whether saturate output or not. 
+ */ +void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, + LayoutType layout_b, DataType dtype_a, + DataType dtype_b, DataType dtype_c, + const std::string &bit_op, bool sparse, + bool saturate) { + CHECK(bit_op == "xor" || bit_op == "and" || bit_op.empty()) + << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; + bool use_bit_op = !bit_op.empty(); + if (use_bit_op) { + CHECK(dtype_a == DataType::kBit1) + << "Bit operator is only compatible with 1-bit multiplicand."; + } + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); + if (saturate) { + CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || + dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8) + << "Output saturation only applicable to multiplicand type " + "s4/u4/s8/u8."; + } + + if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { + // Only MMA on m8n8k4 for fp16 supports customized layouts. + CHECK(layout_a == LayoutType::kRowMajor && + layout_b == LayoutType::kColumnMajor) + << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," + << LayoutTypeToString(layout_b) << "."; + } + + MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); + bool match = false; + for (const MMAConfig &valid_config : valid_mma_configs) { + if (config == valid_config) { + match = true; + break; + } + } + CHECK(match) << "Cannot find matched MMA configurations."; +} + +void CheckWGMMAConfigValidity(int m, int n, int k, LayoutType layout_a, + LayoutType layout_b, DataType dtype_a, + DataType dtype_b, DataType dtype_c, bool sparse) { + // Same DataType Compatibility as MMA + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); + + // Check if configuration exists in valid_wgmma_configs + WGMMAConfig config(m, n, k, dtype_a, dtype_b, dtype_c, sparse); + bool match = false; + for (const WGMMAConfig &valid_config : valid_wgmma_configs) { + if (config == valid_config) { + match = true; + break; + } + } + CHECK(match) << "Cannot find matched WGMMA configurations for m " << m + << " n " << n << " k " << k << " dtype_a " + << DTypeToString(dtype_a) << " dtype_b " + << DTypeToString(dtype_b) << " dtype_c " + << DTypeToString(dtype_c) << " sparse " << sparse; +} +/*! + * \brief Fragment attributes + */ +class FragAttrs { +public: + explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) + : reg_type(reg_type), size(size), ptr_type(ptr_type) {} + /*! \brief PTX register type */ + char reg_type; + /*! \brief Fragment size */ + uint32_t size; + /*! \brief Fragment pointer type */ + std::string ptr_type; +}; + +/*! + * \brief Fragment attributes of given data type. + */ +inline FragAttrs GetFragAttrs(DataType dtype) { + switch (dtype) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + case DataType::kBit16: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return FragAttrs('r', 32, "(unsigned *)"); + case DataType::kInt32: + return FragAttrs('r', 32, "(int *)"); + case DataType::kFloat32: + return FragAttrs('f', 32, "(float *)"); + case DataType::kFloat64: + return FragAttrs('d', 64, "(double *)"); + default: + ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return FragAttrs('\0', 0, ""); + } +} + +}; // namespace ptx + +/*! + * \brief Get the number of MMA computations for given shape and datatype. 
+ */ +inline uint32_t GetNumMMAComputations(int m, int n, int k, + ptx::DataType dtype) { + if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { + // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. + return 4; + } else { + return 1; + } +} + +/*! + * \brief Return template string, input operands string and output operands + * string. \param m The M in mMnNkK of MMA instructions. \param n The N in + * mMnNkK of MMA instructions. \param k The K in mMnNkK of MMA instructions. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. + * \param sparse Whether it's Sparse MMA or not. + */ +inline std::tuple +GetMMAOperands(int m, int n, int k, ptx::DataType dtype_a, + ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse) { + std::stringstream templates, inputs, outputs; + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); + constexpr uint32_t warp_size = 32; + const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) / + frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_b = + (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, + num_operands_c = + (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_b; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + templates << ", %" << (arg_counter++) << ", F"; + } + + // generate inputs + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; + } + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type + << "(A))[" << i << "])"; + } + for (int i = 0; i < num_operands_b; ++i) { + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type + << "(B))[" << i << "])"; + } + for (int i = 0; i < num_operands_c; ++i) { + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "(C))[" << i << "])"; + } + // input of metadata for sparse mma. 
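// Note: E expands to the metadata pointer plus metadata_offset and is bound as
// one extra 32-bit "r" register operand, whereas the sparsity selector F is
// spliced literally into the instruction template above by the Replacer rather
// than passed through a register constraint.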
+ if (sparse) { + inputs << ", \"r\"(((unsigned *)(E))[0])"; + } + + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "(D))[" << i << "])"; + } + return std::make_tuple(templates.str(), inputs.str(), outputs.str()); +} + +inline std::tuple +GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a, + ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse, + bool a_is_shared) { + std::stringstream templates, inputs, outputs, predicate; + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); + constexpr uint32_t warp_size = 32; + const uint32_t threads = + 4 * warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) / + frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_c = + (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + const bool support_ldmatrix_transposed = + ptx::DTypeBits(dtype_a) == 16 && ptx::DTypeBits(dtype_b) == 16; + const bool support_scale_input = + !ptx::DTypeIsInteger(dtype_a) || !ptx::DTypeIsInteger(dtype_b); + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + if (!a_is_shared) { + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + } else { + templates << "}, %" << arg_counter++; + } + + // desc_b + templates << ", " + << "%" << arg_counter++; + + // scale_out + predicate << "%" << arg_counter++; + templates << ", " + << "p"; + + // scale_in_a + if (support_scale_input) { + templates << ", " + << "%" << arg_counter++; + // scale_in_b + templates << ", " + << "%" << arg_counter++; + } + if (support_ldmatrix_transposed) { + if (a_is_shared) { + // trans_a + templates << ", " + << "%" << arg_counter++; + } + // trans_b + templates << ", " + << "%" << arg_counter++; + } + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + LOG(FATAL) << "Sparse WGMMA is not supported yet."; + } + + // generate inputs + if (a_is_shared) { + inputs << "\"l\"(uint64_t((desc_a) + (A_offset)))"; + } else { + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; + } + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type + << "((A)))[" << i << "])"; + } + } + inputs << ", \"l\"(uint64_t((desc_b) + (B_offset)))"; + + // input of metadata for sparse mma. 
+ if (sparse) { + inputs << ", \"r\"(((unsigned *)((E)))[0])"; + } + + inputs << ", \"r\"(int32_t((scale_out)))"; + // scale_in_a + if (support_scale_input) { + inputs << ", \"n\"(int32_t((scale_in_a)))"; + // scale_in_b + inputs << ", \"n\"(int32_t((scale_in_b)))"; + } + if (support_ldmatrix_transposed) { + if (a_is_shared) { + // trans_a + inputs << ", \"n\"(int32_t((trans_a)))"; + } + // trans_b + inputs << ", \"n\"(int32_t((trans_b)))"; + } + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << "\"+" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "((D)))[" << i << "])"; + } + + return std::make_tuple(templates.str(), inputs.str(), outputs.str(), + predicate.str()); +} + +std::string +PrintMMAAssembly(const std::string &shape, const std::string &A_layout, + const std::string &B_layout, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_ptr, const std::string &a_elem_offset, + const std::string &b_ptr, const std::string &b_elem_offset, + const std::string &c_ptr, const std::string &c_elem_offset, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, + const std::string &bit_op, bool sparse, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), + dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + if (dtype_a == ptx::DataType::kFloat32) { + dtype_a = ptx::DataType::kTensorFloat32; + } + if (dtype_b == ptx::DataType::kFloat32) { + dtype_b = ptx::DataType::kTensorFloat32; + } + ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), + layout_b = ptx::LayoutTypeFromString(B_layout); + auto [m, n, k] = ptx::ParseMMAShape(shape); + CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, + bit_op, sparse, saturate); + std::string asm_code = R"( + { + __asm__ __volatile__( + "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}" + "{templates};\n" + : {outputs} + : {inputs}); + } +)"; + auto [templates_str, inputs_str, outputs_str] = + GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse); + + // replace patterns + Replacer replacer; + replacer.register_rule("{.sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{.shape}", "." + shape); + replacer.register_rule("{.saturate}", saturate ? ".satfinite" : ""); + replacer.register_rule("{.alayout}", "." + A_layout); + replacer.register_rule("{.blayout}", "." + B_layout); + replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.bitop}", + bit_op.empty() ? "" : "." 
+ bit_op + ".popc"); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + replacer.register_rule("A", a_ptr + " + " + a_elem_offset); + replacer.register_rule("B", b_ptr + " + " + b_elem_offset); + replacer.register_rule("C", c_ptr + " + " + c_elem_offset); + replacer.register_rule("D", c_ptr + " + " + c_elem_offset); + replacer.register_rule("E", metadata + " + " + metadata_offset); + replacer.register_rule("F", sparsity_selector); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string +PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major, + const bool &b_is_k_major, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_desc, const std::string &A_offset, + const std::string &b_desc, const std::string &B_offset, + const std::string &c_ptr, const std::string &c_offset, + const bool &scale_out, const bool &scale_in_a, + const bool &scale_in_b, const bool &a_is_shared, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, bool sparse) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), + dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + if (dtype_a == ptx::DataType::kFloat32) { + dtype_a = ptx::DataType::kTensorFloat32; + } + if (dtype_b == ptx::DataType::kFloat32) { + dtype_b = ptx::DataType::kTensorFloat32; + } + + ptx::LayoutType layout_a = ptx::LayoutTypeFromBool(!a_is_k_major), + layout_b = ptx::LayoutTypeFromBool(b_is_k_major); + auto [m, n, k] = ptx::ParseMMAShape(shape); + CheckWGMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, + dtype_c, sparse); + std::string asm_code = R"( + { + __asm__ __volatile__( + "{.reg .pred p;\n" + "setp.ne.b32 p, {predicate}, 0;\n" + "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}" + "{templates};\n}" + : {outputs} + : {inputs}); + } +)"; + auto [templates_str, inputs_str, outputs_str, predicate_str] = + GetWGMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse, a_is_shared); + + // replace patterns + Replacer replacer; + replacer.register_rule("{.sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{.shape}", "." + shape); + replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + replacer.register_rule("{predicate}", predicate_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + if (a_is_shared) { + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + } else { + replacer.register_rule("(A)", a_desc + " + " + A_offset); + } + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ptr + " + " + c_offset); + replacer.register_rule("(D)", c_ptr + " + " + c_offset); + replacer.register_rule("(E)", metadata + " + " + metadata_offset); + replacer.register_rule("(F)", sparsity_selector); + replacer.register_rule("(scale_out)", scale_out ? "1" : "0"); + replacer.register_rule("(scale_in_a)", scale_in_a ? 
"1" : "-1"); + replacer.register_rule("(scale_in_b)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(trans_a)", a_is_k_major ? "0" : "1"); + replacer.register_rule("(trans_b)", b_is_k_major ? "0" : "1"); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +inline std::tuple +GetLoadMatrixOperands(int num, const std::string &local_ptr, + const std::string &local_elem_offset) { + std::stringstream templates, outputs; + int arg_counter = 0; + // generate templates + templates << "{%" << arg_counter++; + for (int i = 1; i < num; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, [%" << arg_counter++ << "]"; + // generate outputs + std::string ptr_type = "(unsigned *)"; + for (int i = 0; i < num; ++i) { + if (i != 0) { + outputs << ", "; + } + outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " + << local_elem_offset << "))[" << i << "])"; + } + return std::make_tuple(templates.str(), outputs.str()); +} + +std::string PrintLoadMatrixAssembly(bool trans, int num, + const std::string &type, + const std::string &local_ptr, + const std::string &local_elem_offset, + const std::string &smem_ptr, + const std::string &smem_elem_offset) { + CHECK(num == 1 || num == 2 || num == 4) + << "ldmatrix only accept loading 1/2/4 matrices."; + ptx::DataType data_type = ptx::DTypeFromString(type); + CHECK(data_type == ptx::DataType::kBit16) + << "ldmatrix only accept matrix with type .b16."; + std::string asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + __asm__ __volatile__( + "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}" + "{templates};\n" + : {outputs} + : "r"(addr) + ); + } +)"; + auto [templates_str, outputs_str] = + GetLoadMatrixOperands(num, local_ptr, local_elem_offset); + + Replacer replacer; + replacer.register_rule("{.shape}", ".m8n8"); + replacer.register_rule("{.num}", ".x" + std::to_string(num)); + replacer.register_rule("{.trans}", trans ? ".trans" : ""); + replacer.register_rule("{.ss}", ".shared"); + replacer.register_rule("{.type}", ptx::DTypeToString(data_type)); + replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintCpAsyncAssembly(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes) { + std::string asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? 
"cg" : "ca"); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintPredicatedCpAsyncAssembly( + const std::string &shared_ptr, const std::string &shared_elem_offset, + const std::string &global_ptr, const std::string &global_elem_offset, + const std::string &bytes, const std::string &predicate_value) { + CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || + bytes == "2" || bytes == "1") + << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async"; + std::string predicated_asm_code = R"( + { + unsigned int addr = cast_smem_ptr_to_int({smem_addr}); + int pred_guard = (int){pred_guard}; + __asm__ __volatile__( + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;" + #endif + " @!p {store_shared};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg} + ); + } +)"; + auto [store_shared, nopreg] = [](const std::string &bytes) { + if (bytes == "16") + return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}", + "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)"); + else if (bytes == "12") + return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", + "\"r\"(0), \"r\"(0), \"r\"(0)"); + else if (bytes == "8") + return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", + "\"r\"(0), \"r\"(0)"); + else if (bytes == "4") + return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "2") + return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "1") + return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)"); + else + return std::make_tuple("", ""); + }(bytes); + + Replacer replacer; + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? 
"cg" : "ca"); + replacer.register_rule("{store_shared}", store_shared); + replacer.register_rule("{nopreg}", nopreg); + replacer.register_rule("{pred_guard}", predicate_value); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes, + const std::string &barrier) { + std::string asm_code = R"( + { + unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) + : "memory" + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{smem_addr}", + shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", + global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{barrier}", "&" + barrier); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintCpAsyncBarrierAsm(const std::string &barrier) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "cp.async.mbarrier.arrive.shared.b64 [%0];" + :: "r" (barrier_addr_int) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintInitBarrierThreadCountAsm(const std::string &barrier, + const std::string &thread_count) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + int thread_count = {thread_count}; + __asm__ __volatile__( + "mbarrier.init.shared.b64 [%0], %1;" + :: "r"(barrier_addr_int), "r"(thread_count) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + replacer.register_rule("{thread_count}", thread_count); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintArriveBarrierAsm(const std::string &barrier) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }" + :: "r"(barrier_addr_int) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, + const std::string &byte_count) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + int byte_count = {byte_count}; + __asm__ __volatile__( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + :: "r"(barrier_addr_int), "r"(byte_count) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + replacer.register_rule("{byte_count}", byte_count); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintWaitBarrierAsm(const std::string &barrier) { + std::string predicated_asm_code = R"( + { + unsigned int 
barrier_addr_int = cast_smem_ptr_to_int({barrier}); + constexpr int phase_bit = 0; + __asm__ __volatile__( + "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }" + :: "r"(barrier_addr_int), "r"(phase_bit) + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{barrier}", "&" + barrier); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string GetMMARegisterType(const ptx::DataType &dtype) { + switch (dtype) { + case ptx::DataType::kInt32: + return "unsigned"; + case ptx::DataType::kUInt32: + return "unsigned"; + case ptx::DataType::kFloat32: + return "float"; + case ptx::DataType::kFloat64: + return "double"; + default: + return "unsigned"; + } +} + +} // namespace codegen +} // namespace tvm::tl diff --git a/tilelang/original/src/target/ptx.h b/tilelang/original/src/target/ptx.h new file mode 100644 index 0000000000000000000000000000000000000000..566cded6fa377ee4e471a2271c952f93023dd10c --- /dev/null +++ b/tilelang/original/src/target/ptx.h @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx.h + * \brief Code generation with inlined PTX code. + */ +#ifndef TVM_TL_TARGET_SOURCE_PTX_H_ +#define TVM_TL_TARGET_SOURCE_PTX_H_ + +#include + +#include +#include + +namespace tvm::tl { +namespace codegen { + +namespace ptx { + +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +/*! + * \brief Print ptx data type from string. + */ +DataType DTypeFromString(const std::string str); + +/*! + * \brief Print ptx data type from enum. + */ +std::string DTypeEnumToString(const DataType &dtype); + +/*! + * \brief Print ptx data type from string. + */ +std::string DTypeEnumToString(const std::string &dtype); + +/*! + * \brief Parse MMA shape from string. + */ +std::tuple ParseMMAShape(const std::string &str); +} // namespace ptx + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. 
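+ *
+ * A minimal usage sketch (illustrative only; the rule values below are made
+ * up and not taken from the surrounding code generators):
+ *
+ * \code
+ *   Replacer replacer;
+ *   replacer.register_rule("{.shape}", ".m16n8k16");
+ *   replacer.register_rule("{.dtype}", ".f32");
+ *   std::string s = replacer.rewrite("mma.sync.aligned{.shape}{.dtype}");
+ *   // s == "mma.sync.aligned.m16n8k16.f32"
+ *   replacer.empty_rules();  // clear the rules before reusing the instance
+ * \endcode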
+ */ +class Replacer { +public: + void register_rule(const std::string &pattern, + const std::string &replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto &&rule : _rules) { + auto [pattern, replacement] = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + +private: + std::vector> _rules; +}; + +/*! + * \brief Print MMA assembly string given parameters. + * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of multiplicand C. + * \param a_ptr Pointer to buffer A. + * \param a_offset The offset of element in A. + * \param b_ptr Pointer to buffer B. + * \param b_offset The offset of element in B. + * \param c_ptr Pointer to buffer C. + * \param c_offset The offset of element in C. + * \param metadata Pointer to metadata buffer (only used for sparse mma). + * \param metadata_offset The offset of element in metadata. + * \param sparsity_selector The sparsity selector in sparse mma. + * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or + * "and". \param sparse Whether it's sparse mma or not. \param saturate Whether + * saturate output or not. + */ +std::string +PrintMMAAssembly(const std::string &shape, const std::string &A_layout, + const std::string &B_layout, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_ptr, const std::string &a_offset, + const std::string &b_ptr, const std::string &b_offset, + const std::string &c_ptr, const std::string &c_offset, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, + const std::string &bit_op, bool sparse, bool saturate); + +/*! + * \brief Print WGMMA assembly string given parameters. + * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of multiplicand C. + */ +std::string +PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major, + const bool &b_is_k_major, const std::string &A_dtype, + const std::string &B_dtype, const std::string &C_dtype, + const std::string &a_desc, const std::string &A_offset, + const std::string &b_desc, const std::string &B_offset, + const std::string &c_ptr, const std::string &c_offset, + const bool &scale_out, const bool &scale_in_a, + const bool &scale_in_b, const bool &a_is_shared, + const std::string &metadata, + const std::string &metadata_offset, + const std::string &sparsity_selector, bool sparse); + +/*! + * \brief Print ldmatrix assembly string given parameters. + * \param trans: whether the matrix is loaded in column major format or not. + * \param num: number of matrices to load. + * \param type: The data type in the matrix, .b16 is the only accepted data + * type. 
\param local_ptr: pointer to local buffer. \param local_elem_offset: + * The offset of the element to store in the local buffer. \param smem_ptr: + * pointer to the shared memory buffer to load. \param smem_elem_offset: The + * offset of the start element of the row to load in shared memory. + */ +std::string PrintLoadMatrixAssembly(bool trans, int num, + const std::string &type, + const std::string &local_ptr, + const std::string &local_elem_offset, + const std::string &smem_ptr, + const std::string &smem_elem_offset); + +/*! + * \brief Print ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + */ +std::string PrintCpAsyncAssembly(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes); + +/*! + * \brief Print predicated ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + * \param predicate_value: The value of predicate `@p`. + */ +std::string PrintPredicatedCpAsyncAssembly( + const std::string &shared_ptr, const std::string &shared_elem_offset, + const std::string &global_ptr, const std::string &global_elem_offset, + const std::string &bytes, const std::string &predicate_value); + +/*! + * \brief Print ptx async copy from global to shared memory using cp.async.bulk + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy. + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, + const std::string &shared_elem_offset, + const std::string &global_ptr, + const std::string &global_elem_offset, + const std::string &bytes, + const std::string &barrier); + +/*! + * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintCpAsyncBarrierAsm(const std::string &barrier); + +/*! + * \brief Print ptx barrier initialization of thread count using mbarrier.init + * \param barrier: The name of the barrier in shared memory. + * \param thread_count: The number of threads expected to arrive at the barrier. + */ +std::string PrintInitBarrierThreadCountAsm(const std::string &barrier, + const std::string &thread_count); + +/*! + * \brief Print ptx barrier arrival using mbarrier.arrive + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintArriveBarrierAsm(const std::string &barrier); + +/*! + * \brief Print ptx barrier arrival with expect tx operation using + * mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared + * memory. 
\param byte_count: Increases the tx count of the mbarrier object to + * track completion of additional async transactions. + */ +std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, + const std::string &byte_count); + +/*! + * \brief Print ptx barrier wait using mbarrier.try_wait + * \param barrier: The name of the barrier in shared memory. + */ +std::string PrintWaitBarrierAsm(const std::string &barrier); + +/*! + * \brief Return the register-level C++ type used by MMA fragments. + */ +std::string GetMMARegisterType(const ptx::DataType &dtype); + +} // namespace codegen +} // namespace tvm::tl + +#endif // TVM_TL_TARGET_SOURCE_PTX_H_ diff --git a/tilelang/original/src/target/rt_mod_cpp.cc b/tilelang/original/src/target/rt_mod_cpp.cc new file mode 100644 index 0000000000000000000000000000000000000000..10e3d57b6a23a0eba84f28f5c5fe2a608ea885d8 --- /dev/null +++ b/tilelang/original/src/target/rt_mod_cpp.cc @@ -0,0 +1,79 @@ +#include "codegen_cpp.h" +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace codegen { + +ffi::Module BuildCPPHost(IRModule mod, Target target) { + bool output_ssa = false; + bool emit_asserts = false; + bool emit_fwd_func_decl = true; + + std::unordered_set devices; + if (mod->GetAttr>("device_contexts") != nullptr) { + Map device_contexts = + mod->GetAttr>("device_contexts").value(); + for (auto const &context : device_contexts) { + devices.insert(context.second.data()); + } + } + + CodeGenTileLangCPP cg; + cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); + cg.SetConstantsByteAlignment( + target->GetAttr("constants-byte-alignment").value_or(16)); + + auto is_aot_executor_fn = [](const PrimFunc &func) -> bool { + return func->GetAttr("runner_function", Bool(false)).value(); + }; + + std::vector> funcs; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) + << "CodegenCHost: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + funcs.push_back({gvar, prim_func}); + } + + // Sort functions + auto sort_key = [&is_aot_executor_fn](const auto &kv) { + return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; + }; + std::sort(funcs.begin(), funcs.end(), + [&sort_key](const auto &kv_a, const auto &kv_b) { + return sort_key(kv_a) < sort_key(kv_b); + }); + + // Declare all functions first. This ensures that all functions, + // including the __tvm_main__ used in AOT, have access to forward + // declarations of other functions in the IRModule. + for (const auto &[gvar, prim_func] : funcs) { + cg.DeclareFunction(gvar, prim_func); + } + + // Codegen all functions. Passing emit_fwd_func_decl=true adds a + // forward declaration for any `builtin::call_extern`, based on the + // arguments provided to it. 
+ for (const auto &[gvar, prim_func] : funcs) { + cg.AddFunction(prim_func); + } + + if (target->GetAttr("system-lib").value_or(Bool(false))) { + ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") + << "c target only supports generating C runtime SystemLibs"; + } + + std::string code = cg.Finish(); + return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost); +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/rt_mod_cuda.cc b/tilelang/original/src/target/rt_mod_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5e9b29908d8cf55e2636e454861c6689c5d8234 --- /dev/null +++ b/tilelang/original/src/target/rt_mod_cuda.cc @@ -0,0 +1,114 @@ +#include "codegen_cuda.h" +#include "runtime/cuda/cuda_module.h" +#include "runtime/pack_args.h" +#include +#include + +namespace tvm { +namespace codegen { + +static std::unordered_map +ExtractFuncInfo(const IRModule &mod) { + std::unordered_map fmap; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + + runtime::FunctionInfo info; + for (size_t i = 0; i < f->params.size(); ++i) { + if (f->params[i]->dtype.is_handle()) { + auto ptr = f->params[i]->type_annotation.as(); + if (ptr && ptr->storage_scope == "grid_constant") { + info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1)); + continue; + } + } + DataType dtype = f->params[i].dtype(); + // Device runtime cannot directly take bool arguments, map to int32. + if (dtype.is_bool()) + dtype = DataType::Int(32); + info.arg_types.push_back(dtype); + } + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { + for (const auto &tag : opt.value()) { + info.launch_param_tags.push_back(tag); + } + } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + fmap[static_cast(global_symbol.value())] = info; + } + return fmap; +} + +ffi::Module BuildTileLangCUDA(IRModule mod, Target target) { + bool output_ssa = false; + CodeGenTileLangCUDA cg; + cg.Init(output_ssa); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangCUDA: Can only take PrimFunc"; + auto gvar = Downcast(kv.first); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(gvar, f); + } + + std::string code = cg.Finish(); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) { + code = (*f)(code, target).cast(); + } + std::string fmt = "ptx"; + std::string ptx; + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) { + // Fetch current pass context config and pass into the compile callback + tvm::transform::PassContext pass_ctx = + tvm::transform::PassContext::Current(); + ptx = (*f)(code, target, pass_ctx->config).cast(); + if (ptx[0] != '/') + fmt = "cubin"; + } else { + ICHECK(0); + } + return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); +} + +ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { + bool output_ssa = false; + CodeGenTileLangCUDA cg; + cg.Init(output_ssa); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangCUDA: Can only take PrimFunc"; + auto gvar = Downcast(kv.first); + auto f = Downcast(kv.second); + auto 
calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(gvar, f); + } + + std::string code = cg.Finish(); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) { + code = (*f)(code, target).cast(); + } + return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.tilelang_cuda", BuildTileLangCUDA) + .def("target.build.tilelang_cuda_without_compile", + BuildTileLangCUDAWithoutCompile); +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/rt_mod_cutedsl.cc b/tilelang/original/src/target/rt_mod_cutedsl.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2b6d05d110d6c150f5f47b897c58e1873767894 --- /dev/null +++ b/tilelang/original/src/target/rt_mod_cutedsl.cc @@ -0,0 +1,69 @@ +#include "codegen_cutedsl.h" +#include "runtime/cuda/cuda_module.h" +#include "runtime/pack_args.h" +#include + +namespace tvm { +namespace codegen { + +static std::unordered_map +ExtractFuncInfo(const IRModule &mod) { + std::unordered_map fmap; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + + runtime::FunctionInfo info; + for (size_t i = 0; i < f->params.size(); ++i) { + if (f->params[i]->dtype.is_handle()) { + auto ptr = f->params[i]->type_annotation.as(); + if (ptr && ptr->storage_scope == "grid_constant") { + info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1)); + continue; + } + } + info.arg_types.push_back(f->params[i].dtype()); + } + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { + for (const auto &tag : opt.value()) { + info.launch_param_tags.push_back(tag); + } + } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + fmap[static_cast(global_symbol.value())] = info; + } + return fmap; +} + +ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) { + CodeGenTileLangCuTeDSL cg; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangCuTeDSL: Can only take PrimFunc"; + auto gvar = Downcast(kv.first); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(gvar, f); + } + + std::string code = cg.Finish(); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc")) { + code = (*f)(code, target).cast(); + } + return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile", + BuildTileLangCuTeDSLWithoutCompile); +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/rt_mod_hip.cc b/tilelang/original/src/target/rt_mod_hip.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e5c689c6e1f81a4723ab25b642498225f9e682a --- /dev/null +++ b/tilelang/original/src/target/rt_mod_hip.cc @@ -0,0 +1,127 @@ +#if defined(__linux__) +#include +#include +#endif + +#include +#include + +#include "codegen_hip.h" +#include "runtime/rocm/rocm_module.h" +#include + +#ifndef kTVMGridConstant +#define kTVMGridConstant 130 +#endif + +namespace tvm { +namespace codegen { + +static std::unordered_map 
+ExtractFuncInfo(const IRModule &mod) { + std::unordered_map fmap; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + + runtime::FunctionInfo info; + for (size_t i = 0; i < f->params.size(); ++i) { + if (f->params[i]->dtype.is_handle()) { + auto ptr = f->params[i]->type_annotation.as(); + if (ptr && ptr->storage_scope == "grid_constant") { + info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1)); + continue; + } + } + DataType dtype = f->params[i].dtype(); + // Device runtime cannot directly take bool arguments, map to int32. + if (dtype.is_bool()) + dtype = DataType::Int(32); + info.arg_types.push_back(dtype); + } + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { + for (const auto &tag : opt.value()) { + info.launch_param_tags.push_back(tag); + } + } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + fmap[static_cast(global_symbol.value())] = info; + } + return fmap; +} + +ffi::Module BuildTileLangHIP(IRModule mod, Target target) { + bool output_ssa = false; + CodeGenTileLangHIP cg; + cg.Init(output_ssa); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangHIP: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(f); + } + + std::string code = cg.Finish(); + + // Use the new FFI API to get registered functions + using ffi::Function; + if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) { + code = (*f)(code, target).cast(); + } + + std::string fmt = "ptx"; + std::string ptx; + + if (auto f = Function::GetGlobal("tilelang_callback_hip_compile")) { + ptx = (*f)(code, target).cast(); + if (ptx[0] != '/') + fmt = "hsaco"; + } else { + ICHECK(false) << "tilelang_callback_hip_compile is not set"; + } + + return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); +} + +ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { + bool output_ssa = false; + CodeGenTileLangHIP cg; + cg.Init(output_ssa); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangHIP: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(f); + } + + std::string code = cg.Finish(); + + // Use the new FFI API to get registered functions + using ffi::Function; + if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) { + code = (*f)(code, target).cast(); + } + + return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code, + std::string()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.tilelang_hip", BuildTileLangHIP) + .def("target.build.tilelang_hip_without_compile", + BuildTileLangHIPWithoutCompile); +} + +} // namespace codegen +} // namespace tvm diff --git a/tilelang/original/src/target/rt_mod_metal.cc b/tilelang/original/src/target/rt_mod_metal.cc new file mode 100644 index 0000000000000000000000000000000000000000..2881075c0f0d6c38ee32608a8f388ae1ecbf8371 --- /dev/null +++ b/tilelang/original/src/target/rt_mod_metal.cc @@ -0,0 +1,3 @@ +// Currently mps backend use the codegen from tvm without modification. +// But in the future we're likely to add functions on top of that. 
+// Added an empty file for now. diff --git a/tilelang/original/src/target/utils.cc b/tilelang/original/src/target/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..910b128e6d9eff4b8f61cdd33158409d82cdb5d6 --- /dev/null +++ b/tilelang/original/src/target/utils.cc @@ -0,0 +1,180 @@ +/*! + * \file tl/target/utils.cc + * \brief helper functions for target attributes. + */ + +#include "utils.h" + +#include "../support/ffi_aliases.h" +#include + +namespace tvm { +namespace tl { + +bool TargetIsCuda(Target target) { + return target->GetTargetDeviceType() == kDLCUDA; +} +bool TargetIsRocm(Target target) { + return target->GetTargetDeviceType() == kDLROCM; +} + +int GetArchInt(Target target) { + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); + const std::string arch_str = s.value(); + ICHECK(arch_str.size() >= 3); + ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0) + << "arch string must start with sm_"; + return std::stoi(arch_str.substr(3)); +} + +bool TargetIsVolta(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 70 && arch < 75; +} + +bool TargetIsTuring(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 75 && arch < 80; +} + +bool TargetIsAmpere(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 80 && arch < 90; +} + +bool TargetIsHopper(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 90 && arch < 100; +} + +bool TargetIsSm100(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 100 & arch <= 110; +} + +bool TargetIsSM120(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 120 && arch < 130; +} + +bool TargetIsCDNA(Target target) { + if (!TargetIsRocm(target)) + return false; + if (target->attrs.count("mcpu")) { + std::string mcpu = Downcast(target->attrs.at("mcpu")); + // if mcpu start with "gfx9", it is CDNA + return mcpu.find("gfx9") == 0; + } + return false; +} + +bool TargetIsDCU(Target target) { + if (!TargetIsRocm(target)) + return false; + if (target->attrs.count("mcpu")) { + std::string mcpu = Downcast(target->attrs.at("mcpu")); + // if mcpu start with "gfx936", it is DCU + return mcpu.find("gfx936") == 0; + } + return false; +} + +bool TargetHasAsyncCopy(Target target) { + if (TargetIsCuda(target)) { + int arch = GetArchInt(target); + return arch >= 80; + } else if (TargetIsCDNA(target)) { + if (target->attrs.count("mcpu")) { + std::string mcpu = Downcast(target->attrs.at("mcpu")); + if (mcpu.rfind("gfx9", 0) == 0) { + int gfx_version = std::stoi(mcpu.substr(3, 2)); + return gfx_version >= 94; + } + return false; + } else { + return false; + } + } + + return false; +} +bool TargetHasLdmatrix(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 75; +} + +bool TargetHasStmatrix(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 90; +} + +bool TargetHasTmem(Target target) { + if (!TargetIsCuda(target)) + return false; + return TargetIsSm100(target); +} + +bool TargetHasBulkCopy(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 90; +} + +int TargetGetWarpSize(Target target) { + int res = 32; + if 
(TargetIsCDNA(target)) + res = 64; + return res; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tl.TargetIsCuda", + [](Target target) { return TargetIsCuda(target); }) + .def("tl.TargetIsRocm", + [](Target target) { return TargetIsRocm(target); }) + .def("tl.TargetIsVolta", + [](Target target) { return TargetIsVolta(target); }) + .def("tl.TargetIsTuring", + [](Target target) { return TargetIsTuring(target); }) + .def("tl.TargetIsAmpere", + [](Target target) { return TargetIsAmpere(target); }) + .def("tl.TargetIsHopper", + [](Target target) { return TargetIsHopper(target); }) + .def("tl.TargetIsSM120", + [](Target target) { return TargetIsSM120(target); }) + .def("tl.TargetIsCDNA", + [](Target target) { return TargetIsCDNA(target); }) + .def("tl.TargetHasAsyncCopy", + [](Target target) { return TargetHasAsyncCopy(target); }) + .def("tl.TargetHasLdmatrix", + [](Target target) { return TargetHasLdmatrix(target); }) + .def("tl.TargetHasStmatrix", + [](Target target) { return TargetHasStmatrix(target); }) + .def("tl.TargetHasBulkCopy", + [](Target target) { return TargetHasBulkCopy(target); }) + .def("tl.TargetGetWarpSize", + [](Target target) { return TargetGetWarpSize(target); }); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/target/utils.h b/tilelang/original/src/target/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..74746ae680d866380e9785ad47054d2b6b840f51 --- /dev/null +++ b/tilelang/original/src/target/utils.h @@ -0,0 +1,37 @@ +/*! + * \file tl/target/utils.h + * \brief helper functions for target attributes. + * + */ + +#ifndef TVM_TL_TARGET_UTILS_H_ +#define TVM_TL_TARGET_UTILS_H_ + +#include + +namespace tvm { +namespace tl { + +bool TargetIsCuda(Target target); +bool TargetIsRocm(Target target); + +bool TargetIsVolta(Target target); +bool TargetIsTuring(Target target); +bool TargetIsAmpere(Target target); +bool TargetIsHopper(Target target); +bool TargetIsSm100(Target target); +bool TargetIsSM120(Target target); +bool TargetIsCDNA(Target target); +bool TargetIsDCU(Target target); + +bool TargetHasAsyncCopy(Target target); +bool TargetHasLdmatrix(Target target); +bool TargetHasStmatrix(Target target); +bool TargetHasTmem(Target target); +bool TargetHasBulkCopy(Target target); +int TargetGetWarpSize(Target target); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TARGET_UTILS_H_ diff --git a/tilelang/original/src/tl_templates/cpp/common.h b/tilelang/original/src/tl_templates/cpp/common.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce6580d34b338bbf12b94d51d5092c897fc872e --- /dev/null +++ b/tilelang/original/src/tl_templates/cpp/common.h @@ -0,0 +1,8 @@ +#pragma once + +#include "half.hpp" +#include +#include + +using half_float::half; +// Not Implemented \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/cpp/gemm.h b/tilelang/original/src/tl_templates/cpp/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..1d8fbb7e2a9d460cef00585313abd89368858824 --- /dev/null +++ b/tilelang/original/src/tl_templates/cpp/gemm.h @@ -0,0 +1,3 @@ +#pragma once + +// Not Implemented diff --git a/tilelang/original/src/tl_templates/cpp/half.hpp b/tilelang/original/src/tl_templates/cpp/half.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5410e7572b750e8dd1c51a68e6429b4d3c930ea7 --- /dev/null +++ b/tilelang/original/src/tl_templates/cpp/half.hpp @@ -0,0 +1,5572 @@ +// half - IEEE 
754-based half-precision floating-point library. +// +// Copyright (c) 2012-2025 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Version 2.2.1 + +/// \file +/// Main header file for half-precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) +#define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) +#define HALF_ICC_VERSION __ICC +#elif defined(__ICL) +#define HALF_ICC_VERSION __ICL +#else +#define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang +#if __has_feature(cxx_static_assert) && \ + !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if __has_feature(cxx_user_literals) && \ + !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ + !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif defined(__GNUC__) // gcc +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_GCC_VERSION >= 407 && 
!defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#define HALF_POP_WARNINGS 1 +#pragma warning(push) +#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, + // negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#ifndef HALF_ENABLE_CPP11_CFENV +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#elif defined(__GLIBCXX__) // libstdc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifdef __clang__ +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#else +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) 
+#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || \ + defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ + defined(HALF_ERRHANDLING_THROW_OVERFLOW) || \ + defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ + defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? +#define HALF_ERRHANDLING \ + (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || \ + HALF_ERRHANDLING_FENV || HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING +#define HALF_UNUSED_NOERR(name) name +#else +#define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR +#define HALF_CONSTEXPR constexpr +#define HALF_CONSTEXPR_CONST constexpr +#if HALF_ERRHANDLING +#define HALF_CONSTEXPR_NOERR +#else +#define HALF_CONSTEXPR_NOERR constexpr +#endif +#else +#define HALF_CONSTEXPR +#define HALF_CONSTEXPR_CONST const +#define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT +#define HALF_NOEXCEPT noexcept +#define HALF_NOTHROW noexcept +#else +#define HALF_NOEXCEPT +#define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL +#define HALF_THREAD_LOCAL thread_local +#else +#define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS +#include +#endif +#if HALF_ENABLE_CPP11_CSTDINT +#include +#endif +#if HALF_ERRHANDLING_ERRNO +#include +#endif +#if HALF_ENABLE_CPP11_CFENV +#include +#endif +#if HALF_ENABLE_CPP11_HASH +#include +#endif + +#ifndef HALF_ENABLE_F16C_INTRINSICS +/// Enable F16C instruction set intrinsics. +/// Defining this to 1 enables the use of [F16C compiler +/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between +/// half-precision and single-precision values which may result in improved +/// performance. This will not perform additional checks for support of the F16C +/// instruction set, so an appropriate target platform is required when enabling +/// this feature. +/// +/// Unless predefined it will be enabled automatically when the `__F16C__` +/// symbol is defined, which some compilers do on supporting platforms. +#define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif +#if HALF_ENABLE_F16C_INTRINSICS +#include +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` +/// or `long double`) to override the internal half-precision implementation to +/// use this type for computing arithmetic operations and mathematical functions +/// (if available). This can result in improved performance for arithmetic +/// operators and mathematical functions but might cause results to deviate from +/// the specified half-precision rounding mode and inhibits proper detection of +/// half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. 
+/// Defining this to 1 causes operations on half-precision values to raise +/// internal floating-point exception flags according to the IEEE 754 standard. +/// These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate +/// floating-point exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. +/// Specifically this will propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, +/// overflow and underflow errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact +/// errors won't be propagated. +#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate +/// floating-point exceptions to the built-in single- and double-precision +/// implementation's exception flags using the [C++11 floating-point environment +/// control](https://en.cppreference.com/w/cpp/numeric/fenv) from ``. +/// However, this does not work in reverse and single- or double-precision +/// exceptions will not raise the corresponding half-precision exception flags, +/// nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values +/// to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) +/// with the specified message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values +/// to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) +/// with the specified message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values +/// to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) +/// with the specified message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values +/// to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) +/// with the specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with +/// the specified message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to also raise inexact +/// exceptions. These will be raised after any possible handling of the +/// underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. 
+/// Defining this to 1 (default) causes underflow errors to also raise inexact
+/// exceptions. These will be raised after any possible handling of the
+/// underflow exception.
+///
+/// **Note:** This will actually cause underflow (and the accompanying inexact)
+/// exceptions to be raised *only* when the result is inexact, while if disabled
+/// bare underflow errors will be raised for *any* (possibly exact) subnormal
+/// result.
+#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1
+#endif
+
+/// Default rounding mode.
+/// This specifies the rounding mode used for all conversions between
+/// [half](\ref half_float::half)s and more precise types (unless using
+/// half_cast() and specifying the rounding mode directly) as well as in
+/// arithmetic operations and mathematical functions. It can be redefined
+/// (before including half.hpp) to one of the standard rounding modes using
+/// their respective constants or the equivalent values of
+/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style):
+///
+/// `std::float_round_style`         | value | rounding
+/// ---------------------------------|-------|-------------------------
+/// `std::round_indeterminate`       | -1    | fastest
+/// `std::round_toward_zero`         | 0     | toward zero
+/// `std::round_to_nearest`          | 1     | to nearest (default)
+/// `std::round_toward_infinity`     | 2     | toward positive infinity
+/// `std::round_toward_neg_infinity` | 3     | toward negative infinity
+///
+/// By default this is set to `1` (`std::round_to_nearest`), which rounds
+/// results to the nearest representable value. It can even be set to
+/// [std::numeric_limits<float>::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style)
+/// to synchronize the rounding mode with that of the built-in single-precision
+/// implementation (which is likely `std::round_to_nearest`, though).
+#ifndef HALF_ROUND_STYLE
+#define HALF_ROUND_STYLE 1 // = std::round_to_nearest
+#endif
+
+/// Value signaling overflow.
+/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands to
+/// a positive value signaling the overflow of an operation, in particular it
+/// just evaluates to positive infinity.
+///
+/// **See also:** Documentation for
+/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL)
+#define HUGE_VALH std::numeric_limits<half_float::half>::infinity()
+
+/// Fast half-precision fma function.
+/// This symbol is defined if the fma() function generally executes as fast as,
+/// or faster than, a separate half-precision multiplication followed by an
+/// addition, which is always the case.
+///
+/// **See also:** Documentation for
+/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma)
+#define FP_FAST_FMAH 1
+
+/// Half rounding mode.
+/// In correspondence with `FLT_ROUNDS` from `<cfloat>` this symbol expands to
+/// the rounding mode used for half-precision operations. It is an alias for
+/// [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE).
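+// Illustrative sketch of overriding the default rounding mode (a compile-time
+// choice; values as in the table above):
+//
+//   #define HALF_ROUND_STYLE 0   // std::round_toward_zero
+//   #include "half.hpp"
+//
+//   // 1.0005f lies between the neighbouring half values 1.0 and
+//   // 1.0009765625; toward-zero truncates it to 1.0, whereas the default
+//   // round-to-nearest would pick 1.0009765625.
+//   half_float::half h(1.0005f);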
+/// +/// **See also:** Documentation for +/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 +#define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN +#define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL +#define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO +#define FP_ZERO 1 +#endif +#ifndef FP_NAN +#define FP_NAN 2 +#endif +#ifndef FP_INFINITE +#define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL +#define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) +#define FE_INVALID 0x10 +#define FE_DIVBYZERO 0x08 +#define FE_OVERFLOW 0x04 +#define FE_UNDERFLOW 0x02 +#define FE_INEXACT 0x01 +#define FE_ALL_EXCEPT \ + (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) +#endif + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float { +class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS +/// Library-defined half-precision literals. +/// Import this namespace to enable half-precision floating-point literals: +/// ~~~~{.cpp} +/// using namespace half_float::literal; +/// half_float::half = 4.2_h; +/// ~~~~ +namespace literal { +half operator"" _h(long double); +} +#endif + +/// \internal +/// \brief Implementation details. +namespace detail { +#if HALF_ENABLE_CPP11_TYPE_TRAITS +/// Conditional type. +template +struct conditional : std::conditional {}; + +/// Helper for tag dispatching. +template struct bool_type : std::integral_constant {}; +using std::false_type; +using std::true_type; + +/// Type traits for floating-point types. +template struct is_float : std::is_floating_point {}; +#else +/// Conditional type. +template struct conditional { + typedef T type; +}; +template struct conditional { + typedef F type; +}; + +/// Helper for tag dispatching. +template struct bool_type {}; +typedef bool_type true_type; +typedef bool_type false_type; + +/// Type traits for floating-point types. +template struct is_float : false_type {}; +template struct is_float : is_float {}; +template struct is_float : is_float {}; +template struct is_float : is_float {}; +template <> struct is_float : true_type {}; +template <> struct is_float : true_type {}; +template <> struct is_float : true_type {}; +#endif + +/// Type traits for floating-point bits. +template struct bits { + typedef unsigned char type; +}; +template struct bits : bits {}; +template struct bits : bits {}; +template struct bits : bits {}; + +#if HALF_ENABLE_CPP11_CSTDINT +/// Unsigned integer of (at least) 16 bits width. +typedef std::uint_least16_t uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef std::uint_fast32_t uint32; + +/// Fastest signed integer of (at least) 32 bits width. +typedef std::int_fast32_t int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> struct bits { + typedef std::uint_least32_t type; +}; + +/// Unsigned integer of (at least) 64 bits width. +template <> struct bits { + typedef std::uint_least64_t type; +}; +#else +/// Unsigned integer of (at least) 16 bits width. +typedef unsigned short uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef unsigned long uint32; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef long int32; + +/// Unsigned integer of (at least) 32 bits width. 
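+// The bits<> trait maps a floating-point type to an unsigned integer type of
+// sufficient width, so the conversion routines below can reinterpret values
+// via std::memcpy instead of undefined type punning. Minimal sketch (variable
+// names are placeholders):
+//
+//   float f = 1.0f;
+//   half_float::detail::bits<float>::type fb;   // unsigned integer of >= 32 bits
+//   std::memcpy(&fb, &f, sizeof(float));        // fb == 0x3F800000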
+template <> +struct bits + : conditional::digits >= 32, unsigned int, + unsigned long> {}; + +#if HALF_ENABLE_CPP11_LONG_LONG +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits + : conditional::digits >= 64, + unsigned long, unsigned long long> {}; +#else +/// Unsigned integer of (at least) 64 bits width. +template <> struct bits { + typedef unsigned long type; +}; +#endif +#endif + +#ifdef HALF_ARITHMETIC_TYPE +/// Type to use for arithmetic computations and mathematical functions +/// internally. +typedef HALF_ARITHMETIC_TYPE internal_t; +#endif + +/// Tag type for binary construction. +struct binary_t {}; + +/// Tag for binary construction. +HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + +/// \name Implementation defined classification and arithmetic +/// \{ + +/// Check for infinity. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if infinity +/// \retval false else +template bool builtin_isinf(T arg) { +#if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); +#elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && + !::_isnan(static_cast(arg)); +#else + return arg == std::numeric_limits::infinity() || + arg == -std::numeric_limits::infinity(); +#endif +} + +/// Check for NaN. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if not a number +/// \retval false else +template bool builtin_isnan(T arg) { +#if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); +#elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; +#else + return arg != arg; +#endif +} + +/// Check sign. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if signbit set +/// \retval false else +template bool builtin_signbit(T arg) { +#if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); +#else + return arg < T() || (arg == T() && T(1) / arg < T()); +#endif +} + +/// Platform-independent sign mask. +/// \param arg integer value in two's complement +/// \retval -1 if \a arg negative +/// \retval 0 if \a arg positive +inline uint32 sign_mask(uint32 arg) { + static const int N = std::numeric_limits::digits - 1; +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; +#else + return -((arg >> N) & 1); +#endif +} + +/// Platform-independent arithmetic right shift. +/// \param arg integer value in two's complement +/// \param i shift amount (at most 31) +/// \return \a arg right shifted for \a i bits with possible sign extension +inline uint32 arithmetic_shift(uint32 arg, int i) { +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; +#else + return static_cast(arg) / (static_cast(1) << i) - + ((arg >> (std::numeric_limits::digits - 1)) & 1); +#endif +} + +/// \} +/// \name Error handling +/// \{ + +/// Internal exception flags. +/// \return reference to global exception flags +inline int &errflags() { + HALF_THREAD_LOCAL int flags = 0; + return flags; +} + +/// Raise floating-point exception. 
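+// Worked examples for the bit helpers above (results shown for the common
+// two's-complement case with 32-bit values):
+//
+//   sign_mask(0x80000000)           == 0xFFFFFFFF   // negative input
+//   sign_mask(0x00000005)           == 0x00000000   // non-negative input
+//   arithmetic_shift(0xFFFFFFF0, 4) == 0xFFFFFFFF   // -16 >> 4 == -1, sign-extended
+//   arithmetic_shift(0x000000F0, 4) == 0x0000000F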
+/// \param flags exceptions to raise +/// \param cond condition to raise exceptions for +inline void raise(int HALF_UNUSED_NOERR(flags), + bool HALF_UNUSED_NOERR(cond) = true) { +#if HALF_ERRHANDLING + if (!cond) + return; +#if HALF_ERRHANDLING_FLAGS + errflags() |= flags; +#endif +#if HALF_ERRHANDLING_ERRNO + if (flags & FE_INVALID) + errno = EDOM; + else if (flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) + errno = ERANGE; +#endif +#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); +#endif +#ifdef HALF_ERRHANDLING_THROW_INVALID + if (flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); +#endif +#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if (flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); +#endif +#ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if (flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if (flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_INEXACT + if (flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); +#endif +#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if ((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + detail::raise(FE_INEXACT); +#endif +#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if ((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + detail::raise(FE_INEXACT); +#endif +#endif +} + +/// Check and signal for any NaN. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \retval true if either \a x or \a y is NaN +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) { +#if HALF_ERRHANDLING + detail::raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); +#endif + return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; +} + +/// Signal and silence signaling NaN. +/// \param nan half-precision NaN value +/// \return quiet NaN +/// \exception FE_INVALID if \a nan is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) { +#if HALF_ERRHANDLING + detail::raise(FE_INVALID, !(nan & 0x200)); +#endif + return nan | 0x200; +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, + unsigned int y) { +#if HALF_ERRHANDLING + detail::raise(FE_INVALID, ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || + ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \param z third half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, + unsigned int z) { +#if HALF_ERRHANDLING + detail::raise(FE_INVALID, ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || + ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || + ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) + : ((y & 0x7FFF) > 0x7C00) ? 
(y | 0x200) + : (z | 0x200); +} + +/// Select value or signaling NaN. +/// \param x preferred half-precision value +/// \param y ignored half-precision value except for signaling NaN +/// \return \a y if signaling NaN, \a x otherwise +/// \exception FE_INVALID if \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int +select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) { +#if HALF_ERRHANDLING + return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? detail::signal(y) : x; +#else + return x; +#endif +} + +/// Raise domain error and return NaN. +/// return quiet NaN +/// \exception FE_INVALID +inline HALF_CONSTEXPR_NOERR unsigned int invalid() { +#if HALF_ERRHANDLING + detail::raise(FE_INVALID); +#endif + return 0x7FFF; +} + +/// Raise pole error and return infinity. +/// \param sign half-precision value with sign bit only +/// \return half-precision infinity with sign of \a sign +/// \exception FE_DIVBYZERO +inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) { +#if HALF_ERRHANDLING + detail::raise(FE_DIVBYZERO); +#endif + return sign | 0x7C00; +} + +/// Check value for underflow. +/// \param arg non-zero half-precision value to check +/// \return \a arg +/// \exception FE_UNDERFLOW if arg is subnormal +inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) { +#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + detail::raise(FE_UNDERFLOW, !(arg & 0x7C00)); +#endif + return arg; +} + +/// \} +/// \name Conversion and rounding +/// \{ + +/// Half-precision overflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded overflowing half-precision value +/// \exception FE_OVERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) { +#if HALF_ERRHANDLING + detail::raise(FE_OVERFLOW); +#endif + return (R == std::round_toward_infinity) ? (sign + 0x7C00 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) + ? (sign + 0x7BFF + (sign >> 15)) + : (R == std::round_toward_zero) ? (sign | 0x7BFF) + : (sign | 0x7C00); +} + +/// Half-precision underflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded underflowing half-precision value +/// \exception FE_UNDERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) { +#if HALF_ERRHANDLING + detail::raise(FE_UNDERFLOW); +#endif + return (R == std::round_toward_infinity) ? (sign + 1 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) + : sign; +} + +/// Round half-precision number. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only +/// for rounded results \param value finite half-precision number to round +/// \param g guard bit (most significant discarded bit) +/// \param s sticky bit (or of all but the most significant discarded bits) +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) { +#if HALF_ERRHANDLING + value += (R == std::round_to_nearest) ? (g & (s | value)) + : (R == std::round_toward_infinity) ? (~(value >> 15) & (g | s)) + : (R == std::round_toward_neg_infinity) ? 
((value >> 15) & (g | s)) + : 0; + if ((value & 0x7C00) == 0x7C00) + detail::raise(FE_OVERFLOW); + else if (value & 0x7C00) + detail::raise(FE_INEXACT, I || (g | s) != 0); + else + detail::raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || + (g | s) != 0); + return value; +#else + return (R == std::round_to_nearest) ? (value + (g & (s | value))) + : (R == std::round_toward_infinity) + ? (value + (~(value >> 15) & (g | s))) + : (R == std::round_toward_neg_infinity) + ? (value + ((value >> 15) & (g | s))) + : value; +#endif +} + +/// Round half-precision number to nearest integer value. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never +/// raise it \param value half-precision value to round \return nearest integral +/// half-precision value \exception FE_INVALID for signaling NaN \exception +/// FE_INEXACT if value had to be rounded and \a I is `true` +template +unsigned int integral(unsigned int value) { + unsigned int abs = value & 0x7FFF; + if (abs < 0x3C00) { + detail::raise(FE_INEXACT, I); + return ((R == std::round_to_nearest) + ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) + : (R == std::round_toward_infinity) + ? (0x3C00 & -(~(value >> 15) & (abs != 0))) + : (R == std::round_toward_neg_infinity) + ? (0x3C00 & -static_cast(value > 0x8000)) + : 0) | + (value & 0x8000); + } + if (abs >= 0x6400) + return (abs > 0x7C00) ? detail::signal(value) : value; + unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; + detail::raise(FE_INEXACT, I && (value & mask)); + return (((R == std::round_to_nearest) + ? ((1 << (exp - 1)) - (~(value >> exp) & E)) + : (R == std::round_toward_infinity) ? (mask & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) + : 0) + + value) & + ~mask; +} + +/// Convert fixed point to half-precision floating-point. +/// \tparam R rounding mode to use +/// \tparam F number of fractional bits in [11,31] +/// \tparam S `true` for signed, `false` for unsigned +/// \tparam N `true` for additional normalization step, `false` if already +/// normalized to 1.F \tparam I `true` to always raise INEXACT exception, +/// `false` to raise only for rounded results \param m mantissa in Q1.F fixed +/// point format \param exp biased exponent - 1 \param sign half-precision value +/// with sign bit only \param s sticky bit (or of all but the most significant +/// already discarded bits) \return value converted to half-precision \exception +/// FE_OVERFLOW on overflows \exception FE_UNDERFLOW on underflows \exception +/// FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, + int s = 0) { + if (S) { + uint32 msign = sign_mask(m); + m = (m ^ msign) - msign; + sign = msign & 0x8000; + } + if (N) + for (; m < (static_cast(1) << F) && exp; m <<= 1, --exp) + ; + else if (exp < 0) + return rounded( + sign + static_cast(m >> (F - 10 - exp)), + static_cast((m >> (F - 11 - exp)) & 1), + s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); + return rounded( + sign + (exp << 10) + static_cast(m >> (F - 10)), + static_cast((m >> (F - 11)) & 1), + s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); +} + +/// Convert IEEE single-precision to half-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
\tparam R +/// rounding mode to use \param value single-precision value to convert \return +/// rounded half-precision value \exception FE_OVERFLOW on overflows \exception +/// FE_UNDERFLOW on underflows \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(float value, true_type) { +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph( + _mm_set_ss(value), + (R == std::round_to_nearest) ? _MM_FROUND_TO_NEAREST_INT + : (R == std::round_toward_zero) ? _MM_FROUND_TO_ZERO + : (R == std::round_toward_infinity) ? _MM_FROUND_TO_POS_INF + : (R == std::round_toward_neg_infinity) ? _MM_FROUND_TO_NEG_INF + : _MM_FROUND_CUR_DIRECTION)); +#else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); +#if 1 + unsigned int sign = (fbits >> 16) & 0x8000; + fbits &= 0x7FFFFFFF; + if (fbits >= 0x7F800000) + return sign | 0x7C00 | + ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); + if (fbits >= 0x47800000) + return overflow(sign); + if (fbits >= 0x38800000) + return rounded(sign | (((fbits >> 23) - 112) << 10) | + ((fbits >> 13) & 0x3FF), + (fbits >> 12) & 1, (fbits & 0xFFF) != 0); + if (fbits >= 0x33000000) { + int i = 125 - (fbits >> 23); + fbits = (fbits & 0x7FFFFF) | 0x800000; + return rounded(sign | (fbits >> (i + 1)), (fbits >> i) & 1, + (fbits & ((static_cast(1) << i) - 1)) != + 0); + } + if (fbits != 0) + return underflow(sign); + return sign; +#else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, + 0x0020, 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, + 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, + 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, + 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 
0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, + 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, + 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, + 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, 0xC400, 0xC800, + 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, + 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, + 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 13}; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); + return rounded(base_table[sexp] + (fbits >> i), (m >> (i - 1)) & 1, + (((static_cast(1) << (i - 1)) - 1) & m) != + 0); +#endif +#endif +} + 
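+// Worked example for the single-precision conversion above (illustrative):
+//   1.0f has IEEE bits 0x3F800000 (sign 0, exponent 127, mantissa 0).
+//   Rebiasing the exponent from 127 to 15 (subtracting 112) and keeping the
+//   top 10 mantissa bits gives the half pattern 0x3C00, i.e. 1.0.
+//   65536.0f (bits 0x47800000) exceeds the largest finite half (65504), so it
+//   takes the overflow<>() path and becomes 0x7C00 (+infinity) under the
+//   default round-to-nearest mode.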
+/// Convert IEEE double-precision to half-precision. +/// \tparam R rounding mode to use +/// \param value double-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(double value, true_type) { +#if HALF_ENABLE_F16C_INTRINSICS + if (R == std::round_indeterminate) + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), + _MM_FROUND_CUR_DIRECTION)); +#endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi >> 16) & 0x8000; + hi &= 0x7FFFFFFF; + if (hi >= 0x7FF00000) + return sign | 0x7C00 | + ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0); + if (hi >= 0x40F00000) + return overflow(sign); + if (hi >= 0x3F100000) + return rounded( + sign | static_cast(((hi >> 20) - 1008) << 10) | + static_cast((hi >> 10) & 0x3FF), + static_cast((hi >> 9) & 1), ((hi & 0x1FF) | lo) != 0); + if (hi >= 0x3E600000) { + int i = static_cast(1018 - (hi >> 20)); + hi = (hi & 0xFFFFF) | 0x100000; + return rounded( + sign | static_cast(hi >> (i + 1)), + static_cast((hi >> i) & 1), + ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); + } + if ((hi | lo) != 0) + return underflow(sign); + return sign; +} + +/// Convert non-IEEE floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(T value, ...) { + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if (value == T()) + return hbits; + if (builtin_isnan(value)) + return hbits | 0x7FFF; + if (builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if (exp > 16) + return overflow(hbits); + if (exp < -13) + value = std::ldexp(value, 25); + else { + value = std::ldexp(value, 12 - exp); + hbits |= ((exp + 13) << 10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits + (m >> 1), m & 1, frac != T()); +} + +/// Convert floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half(T value) { + return float2half_impl( + value, bool_type::is_iec559 && + sizeof(typename bits::type) == sizeof(T)>()); +} + +/// Convert integer to half-precision floating-point. 
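+// Worked example for the integer conversion defined below (illustrative):
+//   int2half<std::round_to_nearest>(-2): the magnitude 2 normalizes to a
+//   10-bit mantissa with exponent field 16 and zero fraction, encoding
+//   0x4000; adding the sign bit yields 0xC000, i.e. -2.0.
+//   Magnitudes above 65504 (e.g. 70000) take the overflow<>() path and map to
+//   +/-infinity under the default rounding mode.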
+/// \tparam R rounding mode to use +/// \tparam T type to convert (builtin integer type) +/// \param value integral value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_INEXACT if value had to be rounded +template unsigned int int2half(T value) { + unsigned int bits = static_cast(value < 0) << 15; + if (!value) + return bits; + if (value > 0xFFE0 || (bits && value < -0xFFE0)) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + if (bits) + m = -m; + for (; m < 0x400; m <<= 1, --exp) + ; + for (; m > 0x7FF; m >>= 1, ++exp) + ; + bits |= (exp << 10) + m; + return (exp > 24) + ? rounded(bits, + static_cast(value >> (exp - 25)) & 1, + (((1 << (exp - 25)) - 1) & value) != 0) + : bits; +} + +/// Convert half-precision to IEEE single-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). \param +/// value half-precision value to convert \return single-precision value +inline float half2float_impl(unsigned int value, float, true_type) { +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); +#else +#if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } +#else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, + 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, + 0x35400000, 0x35500000, 0x35600000, 0x35700000, 0x35800000, 0x35880000, + 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, + 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, + 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, 0x360C0000, + 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, + 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, + 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, + 0x36700000, 0x36740000, 0x36780000, 0x367C0000, 0x36800000, 0x36820000, + 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, + 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, + 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, + 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, + 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, + 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, + 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, 0x36E00000, 0x36E20000, + 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, + 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, + 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, + 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, + 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, + 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, + 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, 0x37210000, + 0x37220000, 0x37230000, 
0x37240000, 0x37250000, 0x37260000, 0x37270000, + 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, + 0x372E0000, 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, + 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, + 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, + 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, + 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, + 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, + 0x375E0000, 0x375F0000, 0x37600000, 0x37610000, 0x37620000, 0x37630000, + 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, + 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, + 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, + 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, + 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, + 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, + 0x37870000, 0x37878000, 0x37880000, 0x37888000, 0x37890000, 0x37898000, + 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, + 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, + 0x37960000, 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, + 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, + 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, + 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, + 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, + 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, + 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, + 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, + 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, + 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, + 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, + 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, + 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, + 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, + 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, 0x37C80000, 0x37C88000, + 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, + 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, + 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, + 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, + 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, + 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, + 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, 0x37E00000, 0x37E08000, + 0x37E10000, 0x37E18000, 
0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, + 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, + 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, + 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, + 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, + 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, + 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, 0x37F88000, + 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, + 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, + 0x37FF0000, 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, + 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, + 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, + 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, + 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, + 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, + 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, + 0x380B8000, 0x380BC000, 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, + 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, + 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, + 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, + 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, + 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, + 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, + 0x38178000, 0x3817C000, 0x38180000, 0x38184000, 0x38188000, 0x3818C000, + 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, + 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, + 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, + 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, + 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, + 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, + 0x38238000, 0x3823C000, 0x38240000, 0x38244000, 0x38248000, 0x3824C000, + 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, + 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, + 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, + 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, + 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, + 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, + 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, 0x3830C000, + 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, + 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, + 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, + 0x38370000, 0x38374000, 0x38378000, 0x3837C000, 0x38380000, 0x38384000, + 0x38388000, 0x3838C000, 
0x38390000, 0x38394000, 0x38398000, 0x3839C000, + 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, + 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, + 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, + 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, + 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, + 0x38430000, 0x38434000, 0x38438000, 0x3843C000, 0x38440000, 0x38444000, + 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, + 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, + 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, + 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, + 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, + 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, + 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, 0x38504000, + 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, + 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, + 0x38538000, 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, + 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, + 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, + 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, + 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, + 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, + 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, + 0x385F8000, 0x385FC000, 0x38600000, 0x38604000, 0x38608000, 0x3860C000, + 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, + 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, + 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, + 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, + 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, + 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, + 0x386B8000, 0x386BC000, 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, + 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, + 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, + 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, + 0x38730000, 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, + 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, + 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, + 0x38778000, 0x3877C000, 0x38780000, 0x38784000, 0x38788000, 0x3878C000, + 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, + 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, + 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, + 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, + 0x38004000, 0x38006000, 
0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, + 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, + 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, 0x38026000, + 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, + 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, + 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, 0x38060000, 0x38062000, + 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, + 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, + 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, + 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, + 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, + 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, 0x380C0000, 0x380C2000, + 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, + 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, + 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, + 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, + 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, + 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, + 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, 0x38122000, + 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, + 0x3813C000, 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, + 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, + 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, + 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, + 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, + 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, + 0x3819C000, 0x3819E000, 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, + 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, + 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, + 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, + 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, + 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, + 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, + 0x381FC000, 0x381FE000, 0x38200000, 0x38202000, 0x38204000, 0x38206000, + 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, + 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, + 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, + 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, + 0x38244000, 0x38246000, 
0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, + 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, + 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, 0x38264000, 0x38266000, + 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, + 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, + 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, + 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, + 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, + 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, + 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, + 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, + 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, + 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, + 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, 0x38300000, 0x38302000, + 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, + 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, + 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, + 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, + 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, + 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, 0x38360000, 0x38362000, + 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, + 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, + 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, + 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, + 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, + 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, + 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, 0x383C2000, + 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, + 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, + 0x383DC000, 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, + 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, + 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, + 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, + 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, + 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, + 0x3843C000, 0x3843E000, 0x38440000, 0x38442000, 0x38444000, 0x38446000, + 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, + 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, + 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, + 0x38484000, 0x38486000, 
0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, + 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, + 0x3849C000, 0x3849E000, 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, + 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, + 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, + 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, + 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, + 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, + 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, + 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, 0x38504000, 0x38506000, + 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, + 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, + 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, + 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, + 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, + 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, + 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, 0x38566000, + 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, + 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, + 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, 0x385A0000, 0x385A2000, + 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, + 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, + 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, + 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, + 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, + 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, + 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, 0x38600000, 0x38602000, + 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, + 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, + 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, + 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, + 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, + 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, + 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, 0x38662000, + 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, + 0x3867C000, 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, + 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, + 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, + 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, + 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, + 0x386C4000, 0x386C6000, 
0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, + 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, + 0x386DC000, 0x386DE000, 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, + 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, + 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, + 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, + 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, + 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, + 0x3873C000, 0x3873E000, 0x38740000, 0x38742000, 0x38744000, 0x38746000, + 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, + 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, + 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, + 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, + 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, + 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, + 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, + 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, + 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, + 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, + 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, + 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, + 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, + 0x387FC000, 0x387FE000}; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, + 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, + 0x06000000, 0x06800000, 0x07000000, 0x07800000, 0x08000000, 0x08800000, + 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, + 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, + 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, 0x81800000, + 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, + 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, + 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000}; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 0, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + bits::type fbits = + mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + + exponent_table[value >> 10]; +#endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; +#endif +} + +/// Convert half-precision to IEEE double-precision. 
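+// Worked example (illustrative): converting 0x3C00 (1.0) back to float via
+// the tables above looks up offset_table[15] == 1024, mantissa_table[1024] ==
+// 0x38000000 and exponent_table[15] == 0x07800000; their sum is 0x3F800000,
+// i.e. 1.0f. The double-precision path below rebuilds the same value directly
+// as the bit pattern 0x3FF0000000000000.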
+/// \param value half-precision value to convert
+/// \return double-precision value
+inline double half2float_impl(unsigned int value, double, true_type) {
+#if HALF_ENABLE_F16C_INTRINSICS
+  return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value))));
+#else
+  uint32 hi = static_cast<uint32>(value & 0x8000) << 16;
+  unsigned int abs = value & 0x7FFF;
+  if (abs) {
+    hi |= 0x3F000000 << static_cast<unsigned>(abs >= 0x7C00);
+    for (; abs < 0x400; abs <<= 1, hi -= 0x100000)
+      ;
+    hi += static_cast<uint32>(abs) << 10;
+  }
+  bits<double>::type dbits = static_cast<bits<double>::type>(hi) << 32;
+  double out;
+  std::memcpy(&out, &dbits, sizeof(double));
+  return out;
+#endif
+}
+
+/// Convert half-precision to non-IEEE floating-point.
+/// \tparam T type to convert to (builtin floating-point type)
+/// \param value half-precision value to convert
+/// \return floating-point value
+template <typename T> T half2float_impl(unsigned int value, T, ...) {
+  T out;
+  unsigned int abs = value & 0x7FFF;
+  if (abs > 0x7C00)
+    out = (std::numeric_limits<T>::has_signaling_NaN && !(abs & 0x200))
+              ? std::numeric_limits<T>::signaling_NaN()
+              : std::numeric_limits<T>::has_quiet_NaN
+                    ? std::numeric_limits<T>::quiet_NaN()
+                    : T();
+  else if (abs == 0x7C00)
+    out = std::numeric_limits<T>::has_infinity
+              ? std::numeric_limits<T>::infinity()
+              : std::numeric_limits<T>::max();
+  else if (abs > 0x3FF)
+    out = std::ldexp(static_cast<T>((abs & 0x3FF) | 0x400), (abs >> 10) - 25);
+  else
+    out = std::ldexp(static_cast<T>(abs), -24);
+  return (value & 0x8000) ? -out : out;
+}
+
+/// Convert half-precision to floating-point.
+/// \tparam T type to convert to (builtin floating-point type)
+/// \param value half-precision value to convert
+/// \return floating-point value
+template <typename T> T half2float(unsigned int value) {
+  return half2float_impl(
+      value, T(),
+      bool_type<std::numeric_limits<T>::is_iec559 &&
+                sizeof(typename bits<T>::type) == sizeof(T)>());
+}
+
+/// Convert half-precision floating-point to integer.
+/// \tparam R rounding mode to use
+/// \tparam E `true` for round to even, `false` for round away from zero
+/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it
+/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+/// \param value half-precision value to convert
+/// \return rounded integer value
+/// \exception FE_INVALID if value is not representable in type \a T
+/// \exception FE_INEXACT if value had to be rounded and \a I is `true`
+template <std::float_round_style R, bool E, bool I, typename T>
+T half2int(unsigned int value) {
+  unsigned int abs = value & 0x7FFF;
+  if (abs >= 0x7C00) {
+    detail::raise(FE_INVALID);
+    return (value & 0x8000) ? std::numeric_limits<T>::min()
+                            : std::numeric_limits<T>::max();
+  }
+  if (abs < 0x3800) {
+    detail::raise(FE_INEXACT, I);
+    return (R == std::round_toward_infinity) ? T(~(value >> 15) & (abs != 0))
+           : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000)
+                                                   : T();
+  }
+  int exp = 25 - (abs >> 10);
+  unsigned int m = (value & 0x3FF) | 0x400;
+  int32 i = static_cast<int32>(
+      (exp <= 0) ? (m << -exp)
+                 : ((m + ((R == std::round_to_nearest)
+                              ? ((1 << (exp - 1)) - (~(m >> exp) & E))
+                              : (R == std::round_toward_infinity)
+                                    ? (((1 << exp) - 1) & ((value >> 15) - 1))
+                                    : (R == std::round_toward_neg_infinity)
+                                          ? (((1 << exp) - 1) & -(value >> 15))
+                                          : 0)) >>
+                    exp));
+  if ((!std::numeric_limits<T>::is_signed && (value & 0x8000)) ||
+      (std::numeric_limits<T>::digits < 16 &&
+       ((value & 0x8000) ?
(-i < std::numeric_limits::min()) + : (i > std::numeric_limits::max())))) + detail::raise(FE_INVALID); + else if (I && exp > 0 && (m & ((1 << exp) - 1))) + detail::raise(FE_INEXACT); + return static_cast((value & 0x8000) ? -i : i); +} + +/// \} +/// \name Mathematics +/// \{ + +/// upper part of 64-bit multiplication. +/// \tparam R rounding mode to use +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y +template uint32 mulhi(uint32 x, uint32 y) { + uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), + c = (xy & 0xFFFF) + (yx & 0xFFFF) + + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); + return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + + ((R == std::round_to_nearest) ? ((c >> 15) & 1) + : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) + : 0); +} + +/// 64-bit multiplication. +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y rounded to nearest +inline uint32 multiply64(uint32 x, uint32 y) { +#if HALF_ENABLE_CPP11_LONG_LONG + return static_cast( + (static_cast(x) * static_cast(y) + + 0x80000000) >> + 32); +#else + return mulhi(x, y); +#endif +} + +/// 64-bit division. +/// \param x upper 32 bit of dividend +/// \param y divisor +/// \param s variable to store sticky bit for rounding +/// \return (\a x << 32) / \a y +inline uint32 divide64(uint32 x, uint32 y, int &s) { +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx % y != 0), static_cast(xx / y); +#else + y >>= 1; + uint32 rem = x, div = 0; + for (unsigned int i = 0; i < 32; ++i) { + div <<= 1; + if (rem >= y) { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; +#endif +} + +/// Half precision positive modulus. +/// \tparam Q `true` to compute full quotient, `false` else +/// \tparam R `true` to compute signed remainder, `false` for positive remainder +/// \param x first operand as positive finite half-precision value +/// \param y second operand as positive finite half-precision value +/// \param quo address to store quotient at, `nullptr` if \a Q `false` +/// \return modulus of \a x / \a y +template +unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) { + unsigned int q = 0; + if (x > y) { + int absx = x, absy = y, expx = 0, expy = 0; + for (; absx < 0x400; absx <<= 1, --expx) + ; + for (; absy < 0x400; absy <<= 1, --expy) + ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + for (int d = expx - expy; d; --d) { + if (!Q && mx == my) + return 0; + if (mx >= my) { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if (!Q && mx == my) + return 0; + if (mx >= my) { + mx -= my; + ++q; + } + if (Q) { + q &= (1 << (std::numeric_limits::digits - 1)) - 1; + if (!mx) + return *quo = q, 0; + } + for (; mx < 0x400; mx <<= 1, --expy) + ; + x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); + } + if (R) { + unsigned int a, b; + if (y < 0x800) { + a = (x < 0x400) ? (x << 1) : (x + 0x400); + b = y; + } else { + a = x; + b = y - 0x400; + } + if (a > b || (a == b && (q & 1))) { + int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); + int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - + (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); + for (; m < 0x800 && exp > 1; m <<= 1, --exp) + ; + x = 0x8000 + ((exp - 1) << 10) + (m >> 1); + q += Q; + } + } + if (Q) + *quo = q; + return x; +} + +/// Fixed point square root. 
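// Illustrative aside (not part of half.hpp): mulhi()/multiply64() above return the
// upper 32 bits of a 64-bit product so that Q1.31/Q0.31 fixed-point values can be
// multiplied even without a 64-bit integer type. With <cstdint> available, the
// round-to-nearest behaviour of multiply64() boils down to this small sketch
// (multiply64_sketch is a hypothetical name, not a library function):
#include <cstdint>
inline std::uint32_t multiply64_sketch(std::uint32_t x, std::uint32_t y) {
  // Adding 0x80000000 (half a unit of the upper word) before shifting rounds
  // the truncated upper half to nearest.
  return static_cast<std::uint32_t>(
      (static_cast<std::uint64_t>(x) * y + 0x80000000u) >> 32);
}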
+/// \tparam F number of fractional bits +/// \param r radicand in Q1.F fixed point format +/// \param exp exponent +/// \return square root as Q1.F/2 +template uint32 sqrt(uint32 &r, int &exp) { + int i = exp & 1; + r <<= i; + exp = (exp - i) / 2; + uint32 m = 0; + for (uint32 bit = static_cast(1) << F; bit; bit >>= 2) { + if (r < m + bit) + m >>= 1; + else { + r -= m + bit; + m = (m >> 1) + bit; + } + } + return m; +} + +/// Fixed point binary exponential. +/// This uses the BKM algorithm in E-mode. +/// \param m exponent in [0,1) as Q0.31 +/// \param n number of iterations (at most 32) +/// \return 2 ^ \a m as Q1.31 +inline uint32 exp2(uint32 m, unsigned int n = 32) { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, + 0x02DCF2D1, 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, + 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, + 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, + 0x00000003, 0x00000001}; + if (!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for (unsigned int i = 1; i < n; ++i) { + uint32 mz = my + logs[i]; + if (mz <= m) { + my = mz; + mx += mx >> i; + } + } + return mx; +} + +/// Fixed point binary logarithm. +/// This uses the BKM algorithm in L-mode. +/// \param m mantissa in [1,2) as Q1.30 +/// \param n number of iterations (at most 32) +/// \return log2(\a m) as Q0.31 +inline uint32 log2(uint32 m, unsigned int n = 32) { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, + 0x02DCF2D1, 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, + 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, + 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, + 0x00000003, 0x00000001}; + if (m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for (unsigned int i = 1; i < n; ++i) { + uint32 mz = mx + (mx >> i); + if (mz <= m) { + mx = mz; + my += logs[i]; + } + } + return my; +} + +/// Fixed point sine and cosine. +/// This uses the CORDIC algorithm in rotation mode. +/// \param mz angle in [-pi/2,pi/2] as Q1.30 +/// \param n number of iterations (at most 31) +/// \return sine and cosine of \a mz as Q1.30 +inline std::pair sincos(uint32 mz, unsigned int n = 31) { + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, + 0x00FFFAAB, 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, + 0x00040000, 0x00020000, 0x00010000, 0x00008000, 0x00004000, 0x00002000, + 0x00001000, 0x00000800, 0x00000400, 0x00000200, 0x00000100, 0x00000080, + 0x00000040, 0x00000020, 0x00000010, 0x00000008, 0x00000004, 0x00000002, + 0x00000001}; + uint32 mx = 0x26DD3B6A, my = 0; + for (unsigned int i = 0; i < n; ++i) { + uint32 sign = sign_mask(mz); + uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign; + uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign; + mx = tx; + my = ty; + mz -= (angles[i] ^ sign) - sign; + } + return std::make_pair(my, mx); +} + +/// Fixed point arc tangent. +/// This uses the CORDIC algorithm in vectoring mode. 
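// Illustrative aside (not part of half.hpp): sincos() above is the CORDIC
// algorithm in rotation mode on Q1.30 values. The same iteration in plain double
// precision may be easier to read first (cordic_sincos_sketch is a hypothetical
// demo, valid for angles roughly within [-pi/2, pi/2]):
#include <cmath>
#include <utility>
inline std::pair<double, double> cordic_sincos_sketch(double z, int n = 31) {
  // Start at (1/K, 0), where K is the CORDIC gain, so no rescaling is needed;
  // 0x26DD3B6A in the fixed-point version is the same constant as Q1.30.
  double x = 0.6072529350088813, y = 0.0;
  for (int i = 0; i < n; ++i) {
    double d = (z < 0.0) ? -1.0 : 1.0;  // rotate toward the remaining angle
    double t = std::ldexp(d, -i);       // d * 2^-i, the shift in the integer version
    double nx = x - y * t, ny = y + x * t;
    x = nx;
    y = ny;
    z -= d * std::atan(std::ldexp(1.0, -i));
  }
  return std::make_pair(y, x);  // (sin, cos), matching the (my, mx) pair above
}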
+/// \param my y coordinate as Q0.30 +/// \param mx x coordinate as Q0.30 +/// \param n number of iterations (at most 31) +/// \return arc tangent of \a my / \a mx as Q1.30 +inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) { + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, + 0x00FFFAAB, 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, + 0x00040000, 0x00020000, 0x00010000, 0x00008000, 0x00004000, 0x00002000, + 0x00001000, 0x00000800, 0x00000400, 0x00000200, 0x00000100, 0x00000080, + 0x00000040, 0x00000020, 0x00000010, 0x00000008, 0x00000004, 0x00000002, + 0x00000001}; + uint32 mz = 0; + for (unsigned int i = 0; i < n; ++i) { + uint32 sign = sign_mask(my); + uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; + uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; + mx = tx; + my = ty; + mz += (angles[i] ^ sign) - sign; + } + return mz; +} + +/// Reduce argument for trigonometric functions. +/// \param abs half-precision floating-point value +/// \param k value to take quarter period +/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 +inline uint32 angle_arg(unsigned int abs, int &k) { + uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + int exp = (abs >> 10) + (abs <= 0x3FF) - 15; + if (abs < 0x3A48) + return k = 0, m << (exp + 20); +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, + yi = (y + (mask >> 1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f >> 63); + k = static_cast(yi >> (62 - exp)); + return (multiply64(static_cast((sign ? -f : f) >> (31 - exp)), + 0xC90FDAA2) ^ + sign) - + sign; +#else + uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), + yl = (m * 0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1) << (30 - exp)) - 1, + yi = (yh + (mask >> 1)) & ~mask, sign = -static_cast(yi > yh); + k = static_cast(yi >> (30 - exp)); + uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; + return (multiply64((exp > -1) ? (((fh << (1 + exp)) & 0xFFFFFFFF) | + ((fl & 0xFFFFFFFF) >> (31 - exp))) + : fh, + 0xC90FDAA2) ^ + sign) - + sign; +#endif +} + +/// Get arguments for atan2 function. +/// \param abs half-precision floating-point value +/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 +inline std::pair atan2_args(unsigned int abs) { + int exp = -15; + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - ((rexp > -31) + ? ((r >> -rexp) | + ((r & ((static_cast(1) << -rexp) - 1)) != 0)) + : 1); + for (rexp = 0; r < 0x40000000; r <<= 1, --rexp) + ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if (d < 0) + return std::make_pair((d < -14) + ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) + : (my << (14 + d)), + (mx << 14) + (r << 13) / mx); + if (d > 0) + return std::make_pair( + my << 14, + (d > 14) + ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) + : ((d == 14) ? 
mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); + return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); +} + +/// Get exponentials for hyperbolic computation +/// \param abs half-precision floating-point value +/// \param exp variable to take unbiased exponent of larger result +/// \param n number of BKM iterations (at most 32) +/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent +inline std::pair hyperbolic_args(unsigned int abs, int &exp, + unsigned int n = 32) { + uint32 mx = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29), + my; + int e = (abs >> 10) + (abs <= 0x3FF); + if (e < 14) { + exp = 0; + mx >>= 14 - e; + } else { + exp = static_cast(mx >> (45 - e)); + mx = (mx << (e - 14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if (mx > 0x80000000) { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } else + my = mx; + return std::make_pair( + mx, (d < 31) + ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) + : 1); +} + +/// Postprocessing for binary exponential. +/// \tparam R rounding mode to use +/// \param m fractional part of exponent as Q0.31 +/// \param exp absolute value of unbiased exponent +/// \param esign sign of actual exponent +/// \param sign sign bit of result +/// \param n number of BKM iterations (at most 32) +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0, + unsigned int n = 32) { + if (esign) { + exp = -exp - (m != 0); + if (exp < -25) + return underflow(sign); + else if (exp == -25) + return rounded(sign, 1, m != 0); + } else if (exp > 15) + return overflow(sign); + if (!m) + return sign | + (((exp += 15) > 0) ? (exp << 10) : check_underflow(0x200 >> -exp)); + m = exp2(m, n); + int s = 0; + if (esign) + m = divide64(0x80000000, m, s); + return fixed2half(m, exp + 14, sign, s); +} + +/// Postprocessing for binary logarithm. +/// \tparam R rounding mode to use +/// \tparam L logarithm for base transformation as Q1.31 +/// \param m fractional part of logarithm as Q0.31 +/// \param ilog signed integer part of logarithm +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return value base-transformed and converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) { + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; + if (!m) + return 0; + for (; m < 0x80000000; m <<= 1, --exp) + ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if (exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); +} + +/// Hypotenuse square root and postprocessing. 
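// Illustrative aside (not part of half.hpp): the "Qx.y" notation in these helpers
// means a 32-bit fixed-point word with x integer and y fractional bits, i.e. the
// stored integer divided by 2^y. Two hypothetical conversion helpers make the
// scaling used by exp2()/log2() and the *_post() routines explicit:
#include <cmath>
#include <cstdint>
inline std::uint32_t to_q0_31_sketch(double v) {    // v in [0,1)
  return static_cast<std::uint32_t>(std::ldexp(v, 31));
}
inline double from_q1_31_sketch(std::uint32_t m) {  // result in [0,2)
  return std::ldexp(static_cast<double>(m), -31);
}
// For example, exp2() above maps a Q0.31 fraction m to 2^m as Q1.31, so
// from_q1_31_sketch(exp2(to_q0_31_sketch(0.5))) is approximately sqrt(2).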
+/// \tparam R rounding mode to use +/// \param r mantissa as Q2.30 +/// \param exp biased exponent +/// \return square root converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template unsigned int hypot_post(uint32 r, int exp) { + int i = static_cast(r >> 31); + if ((exp += i) > 46) + return overflow(); + if (exp < -34) + return underflow(); + r = (r >> i) | (r & i); + uint32 m = sqrt<30>(r, exp += 15); + return fixed2half(m, exp - 1, 0, r != 0); +} + +/// Division and postprocessing for tangents. +/// \tparam R rounding mode to use +/// \param my dividend as Q1.31 +/// \param mx divisor as Q1.31 +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return quotient converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int tangent_post(uint32 my, uint32 mx, int exp, + unsigned int sign = 0) { + int i = my >= mx, s; + exp += i; + if (exp > 29) + return overflow(sign); + if (exp < -11) + return underflow(sign); + uint32 m = divide64(my >> (i + 1), mx, s); + return fixed2half(m, exp, sign, s); +} + +/// Area function and postprocessing. +/// This computes the value directly in Q2.30 using the representation +/// `asinh|acosh(x) = log(x+sqrt(x^2+|-1))`. \tparam R rounding mode to use +/// \tparam S `true` for asinh, `false` for acosh +/// \param arg half-precision argument +/// \return asinh|acosh(\a arg) converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int area(unsigned int arg) { + int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, + ilog; + uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, + my, r; + for (; abs < 0x400; abs <<= 1, --expy) + ; + expy += abs >> 10; + r = ((abs & 0x3FF) | 0x400) << 5; + r *= r; + int i = static_cast(r >> 31); + expy = 2 * expy + i; + r >>= i; + if (S) { + if (expy < 0) { + r = 0x40000000 + + ((expy > -30) ? ((r >> -expy) | + ((r & ((static_cast(1) << -expy) - 1)) != 0)) + : 1); + expy = 0; + } else { + r += 0x40000000 >> expy; + i = static_cast(r >> 31); + r = (r >> i) | (r & i); + expy += i; + } + } else { + r -= 0x40000000 >> expy; + for (; r < 0x40000000; r <<= 1, --expy) + ; + } + my = sqrt<30>(r, expy); + my = (my << 15) + (r << 14) / my; + if (S) { + mx >>= expy - expx; + ilog = expy; + } else { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = static_cast(my >> 31); + static const int G = S && (R == std::round_to_nearest); + return log2_post(log2(my >> i, 26 + S + G) + (G << 3), + ilog + i, 17, + arg & (static_cast(S) << 15)); +} + +/// Class for 1.31 unsigned floating-point computation +struct f31 { + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) { + for (; abs < 0x400; abs <<= 1, --exp) + ; + m = static_cast((abs & 0x3FF) | 0x400) << 21; + exp += (abs >> 10); + } + + /// Addition operator. 
+ /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) { + if (b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d < 32) ? (b.m >> d) : 0); + int i = (m & 0xFFFFFFFF) < a.m; + return f31(((m + i) >> i) | 0x80000000, a.exp + i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); + if (!m) + return f31(0, -32); + for (; m < 0x80000000; m <<= 1, --exp) + ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) { + uint32 m = multiply64(a.m, b.m); + int i = static_cast(m >> 31); + return f31(m << (1 - i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m + i) >> i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. +}; + +/// Error function and postprocessing. +/// This computes the value directly in Q1.31 using the approximations given +/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). +/// \tparam R rounding mode to use +/// \tparam C `true` for comlementary error function, `false` else +/// \param arg half-precision function argument +/// \return approximated value of error function in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template unsigned int erf(unsigned int arg) { + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), + x2 = x * x * f31(0xB8AA3B29, 0), + t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), + t2 = t * t; + f31 e = + ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + + f31(0x82790637, -2) - + (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * + t / + ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) + : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), + static_cast(x2.m >> (31 - x2.exp)))); + return (!C || sign) + ? fixed2half( + 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) + : (e.exp < -25) ? underflow() + : fixed2half( + e.m >> 1, e.exp + 14, 0, e.m & 1); +} + +/// Gamma function and postprocessing. +/// This approximates the value of either the gamma function or its logarithm +/// directly in Q1.31. 
\tparam R rounding mode to use \tparam L `true` for +/// lograithm of gamma function, `false` for gamma function \param arg +/// half-precision floating-point value \return lgamma/tgamma(\a arg) in +/// half-precision \exception FE_OVERFLOW on overflows \exception FE_UNDERFLOW +/// on underflows \exception FE_INEXACT if \a arg is not a positive integer +template +unsigned int gamma(unsigned int arg) { + /* static const double p[] ={ 2.50662827563479526904, + 225.525584619175212544, -268.295973841304927459, 80.9030806934622512966, + -5.00757863970517583837, 0.0114684895434781459556 }; double t = arg + 4.65, + s = p[0]; for(unsigned int i=0; i<5; ++i) s += p[i+1] / (arg+i); return + std::log(s) + (arg-0.5)*std::log(t) - t; + */ + static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), + x = sign ? (z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), + s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) + + f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + + f31(0xE1868CB7, 7) / x - + f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) - + f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1)); + int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16); + s = f31((static_cast(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), + i) / + lbe; + if (x.exp != -1 || x.m != 0x80000000) { + i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8); + f31 l = f31((static_cast(t.exp) << (31 - i)) + + (log2(t.m >> 1, 30) >> i), + i) / + lbe; + s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l) + : (s + (x - f31(0x80000000, -1)) * l); + } + s = x.exp ? (s - t) : (t - s); + if (bsign) { + if (z.exp >= 0) { + sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1; + for (z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000; + z.m <<= 1, --z.exp) + ; + } + if (z.exp == -1) + z = f31(0x80000000, 0) - z; + if (z.exp < -1) { + z = z * pi; + z.m = sincos(z.m >> (1 - z.exp), 30).first; + for (z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + } else + z = f31(0x80000000, 0); + } + if (L) { + if (bsign) { + f31 l(0x92868247, 0); + if (z.exp < 0) { + uint32 m = log2((z.m + 1) >> 1, 27); + z = f31(-((static_cast(z.exp) << 26) + (m >> 5)), 5); + for (; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + l = l + z / lbe; + } + sign = static_cast( + x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m))) + << 15; + s = sign ? (s - l) : x.exp ? (l - s) : (l + s); + } else { + sign = static_cast(x.exp == 0) << 15; + if (s.exp < -24) + return underflow(sign); + if (s.exp > 15) + return overflow(sign); + } + } else { + s = s * lbe; + uint32 m; + if (s.exp < 0) { + m = s.m >> -s.exp; + s.exp = 0; + } else { + m = (s.m << s.exp) & 0x7FFFFFFF; + s.exp = static_cast(s.m >> (31 - s.exp)); + } + s.m = exp2(m, 27); + if (!x.exp) + s = f31(0x80000000, 0) / s; + if (bsign) { + if (z.exp < 0) + s = s * z; + s = pi / s; + if (s.exp < -24) + return underflow(sign); + } else if (z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1))) + return ((s.exp + 14) << 10) + static_cast(s.m >> 21); + if (s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp + 14, sign); +} +/// \} + +template struct half_caster; +} // namespace detail + +/// Half-precision floating-point type. +/// This class implements an IEEE-conformant half-precision floating-point type +/// with the usual arithmetic operators and conversions. 
It is implicitly +/// convertible to single-precision floating-point, which makes arithmetic +/// expressions and functions with mixed-type operands to be of the most precise +/// operand type. +/// +/// According to the C++98/03 definition, the half type is not a POD type. But +/// according to C++11's less strict and extended definitions it is both a +/// standard layout type and a trivially copyable type (even if not a POD type), +/// which means it can be standard-conformantly copied using raw binary copies. +/// But in this context some more words about the actual size of the type. +/// Although the half is representing an IEEE 16-bit type, it does not +/// necessarily have to be of exactly 16-bits size. But on any reasonable +/// implementation the actual binary representation of this type will most +/// probably not involve any additional "magic" or padding beyond the simple +/// binary representation of the underlying 16-bit IEEE number, even if not +/// strictly guaranteed by the standard. But even then it only has an actual +/// size of 16 bits if your C++ implementation supports an unsigned integer type +/// of exactly 16 bits width. But this should be the case on nearly any +/// reasonable platform. +/// +/// So if your C++ implementation is not totally exotic or imposes special +/// alignment requirements, it is a reasonable assumption that the data of a +/// half is just comprised of the 2 bytes of the underlying IEEE representation. +class half { +public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin + /// types' default-initialization semantics and may be less efficient than no + /// initialization, it is needed to provide proper value-initialization + /// semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) + : data_( + static_cast(detail::float2half(rhs))) { + } + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half &operator=(float rhs) { + data_ = static_cast(detail::float2half(rhs)); + return *this; + } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half &operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half &operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half &operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. 
+ /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half &operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half &operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half &operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half &operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half &operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half &operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half &operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) { + half out(*this); + ++*this; + return out; + } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) { + half out(*this); + --*this; + return out; + } + /// \} + +private: + /// Rounding mode to use + static const std::float_round_style round_style = + (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. 
+ /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT + : data_(static_cast(bits)) {} + + /// Internal binary representation + detail::uint16 data_; + +#ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template + friend std::basic_ostream & + operator<<(std::basic_ostream &, half); + template + friend std::basic_istream & + operator>>(std::basic_istream &, half &); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int *); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char *); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half rsqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half *, half *); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long lrint(half); + friend half nearbyint(half); +#ifdef HALF_ENABLE_CPP11_LONG_LONG + friend long long llround(half); + friend long long llrint(half); +#endif + friend half frexp(half, int *); + friend half scalbln(half, long); + friend half modf(half, half *); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend HALF_CONSTEXPR half copysign(half, half); + friend HALF_CONSTEXPR int fpclassify(half); + friend HALF_CONSTEXPR bool isfinite(half); + friend HALF_CONSTEXPR bool isinf(half); + friend HALF_CONSTEXPR bool isnan(half); + friend HALF_CONSTEXPR bool isnormal(half); + friend HALF_CONSTEXPR bool signbit(half); + friend HALF_CONSTEXPR bool isgreater(half, half); + friend HALF_CONSTEXPR bool isgreaterequal(half, half); + friend HALF_CONSTEXPR bool isless(half, half); + friend HALF_CONSTEXPR bool islessequal(half, half); + friend HALF_CONSTEXPR bool islessgreater(half, half); + template + friend struct detail::half_caster; + friend class std::numeric_limits; +#if HALF_ENABLE_CPP11_HASH + friend struct std::hash; +#endif +#if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator"" 
_h(long double); +#endif +#endif +}; + +#if HALF_ENABLE_CPP11_USER_LITERALS +namespace literal { +/// Half literal. +/// While this returns a properly rounded half-precision value, half literals +/// can unfortunately not be constant expressions due to rather involved +/// conversions. So don't expect this to be a literal literal without involving +/// conversion operations at runtime. It is a convenience feature, not a +/// performance optimization. \param value literal value \return half with of +/// given value (possibly rounded) \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half operator"" _h(long double value) { + return half(detail::binary, detail::float2half(value)); +} +} // namespace literal +#endif + +namespace detail { +/// Helper class for half casts. +/// This class template has to be specialized for all valid cast arguments to +/// define an appropriate static `cast` member function and a corresponding +/// `type` member denoting its return type. \tparam T destination type \tparam U +/// source type \tparam R rounding mode to use +template +struct half_caster {}; +template struct half_caster { +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, + "half_cast from non-arithmetic type not supported"); +#endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + +private: + static half cast_impl(U arg, true_type) { + return half(binary, float2half(arg)); + } + static half cast_impl(U arg, false_type) { + return half(binary, int2half(arg)); + } +}; +template struct half_caster { +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, + "half_cast to non-arithmetic type not supported"); +#endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + +private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { + return half2int(arg.data_); + } +}; +template struct half_caster { + static half cast(half arg) { return arg; } +}; +} // namespace detail +} // namespace half_float + +/// Extensions to the C++ standard library. +namespace std { +/// Numeric limits for half-precision floats. +/// **See also:** Documentation for +/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) +template <> class numeric_limits { +public: + /// Is template specialization. + static HALF_CONSTEXPR_CONST bool is_specialized = true; + + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not an integer type. + static HALF_CONSTEXPR_CONST bool is_integer = false; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// Has a finite set of values. + static HALF_CONSTEXPR_CONST bool is_bounded = true; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Does not support denormalization detection. 
+ static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; + +#if HALF_ERRHANDLING_THROWS + static HALF_CONSTEXPR_CONST bool traps = true; +#else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref + /// HALF_ERRHANDLING_THROW_INVALID) is activated. + static HALF_CONSTEXPR_CONST bool traps = false; +#endif + + /// Does not support pre-rounding underflow detection. + static HALF_CONSTEXPR_CONST bool tinyness_before = false; + + /// Rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = + half_float::half::round_style; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x0400); + } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0xFBFF); + } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x7BFF); + } + + /// Difference between 1 and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x1400); + } + + /// Maximum rounding error in ULP (units in the last place). + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, + (round_style == std::round_to_nearest) ? 0x3800 + : 0x3C00); + } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x7C00); + } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x7FFF); + } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x7DFF); + } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW { + return half_float::half(half_float::detail::binary, 0x0001); + } +}; + +#if HALF_ENABLE_CPP11_HASH +/// Hash function for half-precision floats. +/// This is only defined if C++11 `std::hash` is supported and enabled. +/// +/// **See also:** Documentation for +/// [std::hash](https://en.cppreference.com/w/cpp/utility/hash) +template <> struct hash { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. 
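  // Editorial note (not part of half.hpp): the mask in the hash below zeroes the
  // bit pattern of -0 (0x8000), so that +0 and -0, which compare equal, also
  // produce the same hash value.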
+ /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const { + return hash()( + arg.data_ & -static_cast(arg.data_ != 0x8000)); + } +}; +#endif +} // namespace std + +namespace half_float { +/// \anchor compop +/// \name Comparison operators +/// \{ + +/// Comparison for equality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for inequality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands not equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) { + return detail::compsignal(x.data_, y.data_) || + (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)); +} + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)); +} + +/// Comparison for less equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)); +} + +/// \} +/// \anchor arithmetics +/// \name Arithmetic operators +/// \{ + +/// Identity. +/// \param arg operand +/// \return unchanged operand +inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + +/// Negation. +/// \param arg operand +/// \return negated operand +inline HALF_CONSTEXPR half operator-(half arg) { + return half(detail::binary, arg.data_ ^ 0x8000); +} + +/// Addition. +/// This operation is exact to rounding for all rounding modes. 
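// Illustrative aside (not part of half.hpp): the comparison operators above turn
// the sign-magnitude half bit pattern into an unsigned key that increases
// monotonically with the represented value, so plain integer comparison works and
// +0/-0 compare equal. A hypothetical stand-alone version of that mapping:
#include <cstdint>
inline std::uint32_t ordered_key_sketch(std::uint16_t h) {
  std::uint32_t sign = h >> 15;                      // 1 for negative values
  // Positives get the top bit set (placing them above all negatives); negatives
  // are complemented and incremented, so larger magnitudes sort lower.
  return (h ^ (0x8000u | (0x8000u - sign))) + sign;  // both zeros map to 0x8000
}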
+/// \param x left operand +/// \param y right operand +/// \return sum of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with different signs +/// or signaling NaNs \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according +/// to rounding +inline half operator+(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + detail::half2float(x.data_) + + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy != 0x7C00) ? x.data_ + : (sub && absx == 0x7C00) ? detail::invalid() + : y.data_); + if (!absx) + return absy ? y + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); + if (!absy) + return x; + unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; + if (absy > absx) + std::swap(absx, absy); + int exp = (absx >> 10) + (absx <= 0x3FF), + d = exp - (absy >> 10) - (absy <= 0x3FF), + mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; + if (d < 13) { + my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; + my = (my >> d) | ((my & ((1 << d) - 1)) != 0); + } else + my = 1; + if (sub) { + if (!(mx -= my)) + return half(detail::binary, + static_cast(half::round_style == + std::round_toward_neg_infinity) + << 15); + for (; mx < 0x2000 && exp > 1; mx <<= 1, --exp) + ; + } else { + mx += my; + int i = mx >> 14; + if ((exp += i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx >> i) | (mx & i); + } + return half(detail::binary, detail::rounded( + sign + ((exp - 1) << 10) + (mx >> 3), + (mx >> 2) & 1, (mx & 0x3) != 0)); +#endif +} + +/// Subtraction. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return difference of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or +/// signaling NaNs \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half operator-(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + detail::half2float(x.data_) - + detail::half2float(y.data_))); +#else + return x + -y; +#endif +} + +/// Multiplication. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return product of half expressions +/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is +/// signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half operator*(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + detail::half2float(x.data_) * + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) + ? 
detail::invalid() + : (sign | 0x7C00)); + if (!absx || !absy) + return half(detail::binary, sign); + for (; absx < 0x400; absx <<= 1, --exp) + ; + for (; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = static_cast(m >> 21), s = static_cast(m & i); + exp += (absx >> 10) + (absy >> 10) + i; + if (exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if (exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half( + m >> i, exp, sign, s)); +#endif +} + +/// Division. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return quotient of half expressions +/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a +/// x or \a y is signaling NaN \exception FE_DIVBYZERO if dividing finite value +/// by 0 \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator/(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + detail::half2float(x.data_) / + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == absy) + ? detail::invalid() + : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); + if (!absx) + return half(detail::binary, absy ? sign : detail::invalid()); + if (!absy) + return half(detail::binary, detail::pole(sign)); + for (; absx < 0x400; absx <<= 1, --exp) + ; + for (; absy < 0x400; absy <<= 1, ++exp) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + int i = mx < my; + exp += (absx >> 10) - (absy >> 10) - i; + if (exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if (exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, + detail::fixed2half( + mx / my, exp, sign, mx % my != 0)); +#endif +} + +/// \} +/// \anchor streaming +/// \name Input and output +/// \{ + +/// Output operator. +/// This uses the built-in functionality for streaming out floating-point +/// numbers. +/// \param out output stream to write into +/// \param arg half expression to write +/// \return reference to output stream +template +std::basic_ostream & +operator<<(std::basic_ostream &out, half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); +#else + return out << detail::half2float(arg.data_); +#endif +} + +/// Input operator. +/// This uses the built-in functionality for streaming in floating-point +/// numbers, specifically double precision floating +/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref +/// HALF_ARITHMETIC_TYPE)). So the input string is first rounded to double +/// precision using the underlying platform's current floating-point rounding +/// mode before being rounded to half-precision using the library's +/// half-precision rounding mode. 
\param in input stream to read from \param arg +/// half to read into \return reference to input stream \exception FE_OVERFLOW, +/// ...UNDERFLOW, ...INEXACT according to rounding +template +std::basic_istream & +operator>>(std::basic_istream &in, half &arg) { +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; +#else + double f; +#endif + if (in >> f) + arg.data_ = detail::float2half(f); + return in; +} + +/// \} +/// \anchor basic +/// \name Basic mathematical operations +/// \{ + +/// Absolute value. +/// **See also:** Documentation for +/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). \param arg +/// operand \return absolute value of \a arg +inline HALF_CONSTEXPR half fabs(half arg) { + return half(detail::binary, arg.data_ & 0x7FFF); +} + +/// Absolute value. +/// **See also:** Documentation for +/// [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). \param arg +/// operand \return absolute value of \a arg +inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). \param x +/// first operand \param y second operand \return remainder of floating-point +/// division. \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x +/// or \a y is signaling NaN +inline half fmod(half x, half y) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, + sign = x.data_ & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() + : x.data_); + if (!absy) + return half(detail::binary, detail::invalid()); + if (!absx) + return x; + if (absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign | detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is +/// signaling NaN +inline half remainder(half x, half y) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, + sign = x.data_ & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() + : x.data_); + if (!absy) + return half(detail::binary, detail::invalid()); + if (absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign ^ detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). \param +/// x first operand \param y second operand \param quo address to store some +/// bits of quotient at \return remainder of floating-point division. \exception +/// FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling +/// NaN +inline half remquo(half x, half y, int *quo) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, + value = x.data_ & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? 
detail::invalid() + : (*quo = 0, x.data_)); + if (!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value ^ y.data_) & 0x8000) != 0; + int q = 1; + if (absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); +} + +/// Fused multiply add. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). \param x +/// first operand \param y second operand \param z third operand \return ( \a x +/// * \a y ) + \a z rounded as one operation. \exception FE_INVALID according to +/// operator*() and operator+() unless any argument is a quiet NaN and no +/// argument is a signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding the final addition +inline half fma(half x, half y, half z) { +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); +#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, + detail::float2half(std::fma(fx, fy, fz))); +#else + return half(detail::binary, + detail::float2half(fx * fy + fz)); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, + exp = -15; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + bool sub = ((sign ^ z.data_) & 0x8000) != 0; + if (absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) + ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) + : (absx == 0x7C00) + ? half(detail::binary, (!absy || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : (absy == 0x7C00) + ? half(detail::binary, (!absx || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : z; + if (!absx || !absy) + return absz ? z + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (z.data_ | sign) + : (z.data_ & sign)); + for (; absx < 0x400; absx <<= 1, --exp) + ; + for (; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = static_cast(m >> 21); + exp += (absx >> 10) + (absy >> 10) + i; + m <<= 3 - i; + if (absz) { + int expz = 0; + for (; absz < 0x400; absz <<= 1, --expz) + ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) + << 13; + if (expz > exp || (expz == exp && mz > m)) { + std::swap(m, mz); + std::swap(exp, expz); + if (sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d < 23) ? ((mz >> d) | + ((mz & ((static_cast(1) << d) - 1)) != 0)) + : 1; + if (sub) { + m = m - mz; + if (!m) + return half(detail::binary, + static_cast(half::round_style == + std::round_toward_neg_infinity) + << 15); + for (; m < 0x800000; m <<= 1, --exp) + ; + } else { + m += mz; + i = static_cast(m >> 24); + m = (m >> i) | (m & i); + exp += i; + } + } + if (exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if (exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half( + m, exp - 1, sign)); +#endif +} + +/// Maximum of half expressions. +/// **See also:** Documentation for +/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). 
\param x +/// first operand \param y second operand \return maximum of operands, ignoring +/// quiet NaNs \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) { + return half( + detail::binary, + (!isnan(y) && + (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Minimum of half expressions. +/// **See also:** Documentation for +/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). \param x +/// first operand \param y second operand \return minimum of operands, ignoring +/// quiet NaNs \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) { + return half( + detail::binary, + (!isnan(y) && + (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Positive difference. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). \param x +/// first operand \param y second operand \return \a x - \a y or 0 if difference +/// negative \exception FE_... according to operator-(half,half) +inline half fdim(half x, half y) { + if (isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + ? half(detail::binary, 0) + : (x - y); +} + +/// Get NaN value. +/// **See also:** Documentation for +/// [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). \param arg +/// string code \return quiet NaN +inline half nanh(const char *arg) { + unsigned int value = 0x7FFF; + while (*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); +} + +/// \} +/// \anchor exponential +/// \name Exponential functions +/// \{ + +/// Exponential function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). \param arg +/// function argument \return e raised to \a arg \exception FE_INVALID for +/// signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half exp(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::exp(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, e = (abs >> 10) + (abs <= 0x3FF), exp; + if (!abs) + return half(detail::binary, 0x3C00); + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) + ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if (abs >= 0x4C80) + return half(detail::binary, (arg.data_ & 0x8000) + ? detail::underflow() + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29); + if (e < 14) { + exp = 0; + m >>= 14 - e; + } else { + exp = static_cast(m >> (45 - e)); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + return half(detail::binary, detail::exp2_post( + m, exp, (arg.data_ & 0x8000) != 0, 0, 26)); +#endif +} + +/// Binary exponential. +/// This function is exact to rounding for all rounding modes. 
+/// +/// **See also:** Documentation for +/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). \param arg +/// function argument \return 2 raised to \a arg \exception FE_INVALID for +/// signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half exp2(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::exp2( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, e = (abs >> 10) + (abs <= 0x3FF), + exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); + if (!abs) + return half(detail::binary, 0x3C00); + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) + ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if (abs >= 0x4E40) + return half(detail::binary, (arg.data_ & 0x8000) + ? detail::underflow() + : detail::overflow()); + return half(detail::binary, + detail::exp2_post( + (static_cast(exp) << (6 + e)) & 0x7FFFFFFF, + exp >> (25 - e), (arg.data_ & 0x8000) != 0, 0, 28)); +#endif +} + +/// Exponential minus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% +/// of inputs for `std::round_to_nearest` and in <1% of inputs for any other +/// rounding mode. +/// +/// **See also:** Documentation for +/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). \param +/// arg function argument \return e raised to \a arg and subtracted by 1 +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half expm1(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::expm1( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000, + e = (abs >> 10) + (abs <= 0x3FF), exp; + if (!abs) + return arg; + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) + : detail::signal(arg.data_)); + if (abs >= 0x4A00) + return half(detail::binary, + (arg.data_ & 0x8000) + ? detail::rounded(0xBBFF, 1, 1) + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29); + if (e < 14) { + exp = 0; + m >>= 14 - e; + } else { + exp = static_cast(m >> (45 - e)); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if (sign) { + int s = 0; + if (m > 0x80000000) { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - + ((m >> exp) | + ((m & ((static_cast(1) << exp) - 1)) != 0) | s); + exp = 0; + } else + m -= (exp < 31) ? (0x80000000 >> exp) : 1; + for (exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) + ; + if (exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::rounded( + sign + (exp << 10) + static_cast(m >> 21), + static_cast((m >> 20) & 1), (m & 0xFFFFF) != 0)); +#endif +} + +/// Natural logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). 
\param arg +/// function argument \return logarithm of \a arg to base e \exception +/// FE_INVALID for signaling NaN or negative argument \exception FE_DIVBYZERO +/// for 0 \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if (!abs) + return half(detail::binary, detail::pole(0x8000)); + if (arg.data_ & 0x8000) + return half(detail::binary, (arg.data_ <= 0xFC00) + ? detail::invalid() + : detail::signal(arg.data_)); + if (abs >= 0x7C00) + return (abs == 0x7C00) ? arg + : half(detail::binary, detail::signal(arg.data_)); + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half( + detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 27) + + 8, + exp, 17)); +#endif +} + +/// Common logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). \param +/// arg function argument \return logarithm of \a arg to base 10 \exception +/// FE_INVALID for signaling NaN or negative argument \exception FE_DIVBYZERO +/// for 0 \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log10(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::log10( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if (!abs) + return half(detail::binary, detail::pole(0x8000)); + if (arg.data_ & 0x8000) + return half(detail::binary, (arg.data_ <= 0xFC00) + ? detail::invalid() + : detail::signal(arg.data_)); + if (abs >= 0x7C00) + return (abs == 0x7C00) ? arg + : half(detail::binary, detail::signal(arg.data_)); + switch (abs) { + case 0x4900: + return half(detail::binary, 0x3C00); + case 0x5640: + return half(detail::binary, 0x4000); + case 0x63D0: + return half(detail::binary, 0x4200); + case 0x70E2: + return half(detail::binary, 0x4400); + } + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half( + detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 27) + + 8, + exp, 16)); +#endif +} + +/// Binary logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). \param arg +/// function argument \return logarithm of \a arg to base 2 \exception +/// FE_INVALID for signaling NaN or negative argument \exception FE_DIVBYZERO +/// for 0 \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log2(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::log2( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if (!abs) + return half(detail::binary, detail::pole(0x8000)); + if (arg.data_ & 0x8000) + return half(detail::binary, (arg.data_ <= 0xFC00) + ? detail::invalid() + : detail::signal(arg.data_)); + if (abs >= 0x7C00) + return (abs == 0x7C00) ? 
arg + : half(detail::binary, detail::signal(arg.data_)); + if (abs == 0x3C00) + return half(detail::binary, 0); + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += (abs >> 10); + if (!(abs & 0x3FF)) { + unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) + << 6; + for (exp = 18; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, value + (exp << 10) + m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), + m = (((ilog << 27) + + (detail::log2( + static_cast((abs & 0x3FF) | 0x400) + << 20, + 28) >> + 4)) ^ + sign) - + sign; + if (!m) + return half(detail::binary, 0); + for (exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) + ; + for (; m > 0xFFFFFFF; m >>= 1, ++exp) + s |= m & 1; + return half(detail::binary, + detail::fixed2half( + m, exp, sign & 0x8000, s)); +#endif +} + +/// Natural logarithm plus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% +/// of inputs for `std::round_to_nearest` and in ~1% of inputs for any other +/// rounding mode. +/// +/// **See also:** Documentation for +/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). \param +/// arg function argument \return logarithm of \a arg plus 1 to base e +/// \exception FE_INVALID for signaling NaN or argument <-1 +/// \exception FE_DIVBYZERO for -1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log1p(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::log1p( + detail::half2float(arg.data_)))); +#else + if (arg.data_ >= 0xBC00) + return half(detail::binary, (arg.data_ == 0xBC00) ? detail::pole(0x8000) + : (arg.data_ <= 0xFC00) + ? detail::invalid() + : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if (!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) + : arg; + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; + if (arg.data_ & 0x8000) { + m = 0x40000000 - (m >> -exp); + for (exp = 0; m < 0x40000000; m <<= 1, --exp) + ; + } else { + if (exp < 0) { + m = 0x40000000 + (m >> -exp); + exp = 0; + } else { + m += 0x40000000 >> exp; + int i = static_cast(m >> 31); + m >>= i; + exp += i; + } + } + return half(detail::binary, detail::log2_post( + detail::log2(m), exp, 17)); +#endif +} + +/// \} +/// \anchor power +/// \name Power functions +/// \{ + +/// Square root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). \param arg +/// function argument \return square root of \a arg \exception FE_INVALID for +/// signaling NaN and negative arguments \exception FE_INEXACT according to +/// rounding +inline half sqrt(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::sqrt( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 15; + if (!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() + : arg.data_); + for (; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, + m = detail::sqrt<20>(r, exp += abs >> 10); + return half(detail::binary, detail::rounded( + (exp << 10) + (m & 0x3FF), r > m, r != 0)); +#endif +} + +/// Inverse square root. 
+/// This function is exact to rounding for all rounding modes and thus generally +/// more accurate than directly computing 1 / sqrt(\a arg) in half-precision, in +/// addition to also being faster. \param arg function argument \return +/// reciprocal of square root of \a arg \exception FE_INVALID for signaling NaN +/// and negative arguments \exception FE_INEXACT according to rounding +inline half rsqrt(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half( + detail::internal_t(1) / + std::sqrt(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, bias = 0x4000; + if (!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() + : !abs ? detail::pole(arg.data_ & 0x8000) + : 0); + for (; abs < 0x400; abs <<= 1, bias -= 0x400) + ; + unsigned int frac = (abs += bias) & 0x7FF; + if (frac == 0x400) + return half(detail::binary, 0x7A00 - (abs >> 1)); + if ((half::round_style == std::round_to_nearest && + (frac == 0x3FE || frac == 0x76C)) || + (half::round_style != std::round_to_nearest && + (frac == 0x15A || frac == 0x3FC || frac == 0x401 || frac == 0x402 || + frac == 0x67B))) + return pow(arg, half(detail::binary, 0xB800)); + detail::uint32 f = 0x17376 - abs, mx = (abs & 0x3FF) | 0x400, + my = ((f >> 1) & 0x3FF) | 0x400, mz = my * my; + int expy = static_cast(f >> 11) - 31, expx = 32 - (abs >> 10), + i = static_cast(mz >> 21); + for (mz = 0x60000000 - (((mz >> i) * mx) >> (expx - 2 * expy - i)); + mz < 0x40000000; mz <<= 1, --expy) + ; + i = static_cast((my *= mz >> 10) >> 31); + expy += i; + my = (my >> (20 + i)) + 1; + i = static_cast((mz = my * my) >> 21); + for (mz = 0x60000000 - (((mz >> i) * mx) >> (expx - 2 * expy - i)); + mz < 0x40000000; mz <<= 1, --expy) + ; + i = static_cast((my *= (mz >> 10) + 1) >> 31); + return half(detail::binary, + detail::fixed2half( + my >> i, expy + i + 14)); +#endif +} + +/// Cubic root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). \param arg +/// function argument \return cubic root of \a arg \exception FE_INVALID for +/// signaling NaN \exception FE_INEXACT according to rounding +inline half cbrt(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::cbrt( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if (!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) + : arg; + for (; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + (detail::log2( + static_cast((abs & 0x3FF) | 0x400) + << 20, + 24) >> + 4)) ^ + sign) - + sign; + for (exp = 2; m < 0x80000000; m <<= 1, --exp) + ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = static_cast(m >> 31), s; + exp += i; + m <<= 1 - i; + if (exp < 0) { + f = m >> -exp; + exp = 0; + } else { + f = (m << exp) & 0x7FFFFFFF; + exp = static_cast(m >> (31 - exp)); + } + m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); + if (sign) { + if (m > 0x80000000) { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half( + detail::binary, + (half::round_style == std::round_to_nearest) + ? 
detail::fixed2half( + m, exp + 14, arg.data_ & 0x8000) + : detail::fixed2half( + (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). \param x +/// first argument \param y second argument \return square root of sum of +/// squares without internal over- or underflows \exception FE_INVALID if \a x +/// or \a y is signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding of the final square root +inline half hypot(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_); +#if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::hypot(fx, fy))); +#else + return half(detail::binary, detail::float2half( + std::sqrt(fx * fx + fy * fy))); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) ? detail::select(0x7C00, y.data_) + : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) + : detail::signal(x.data_, y.data_)); + if (!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if (!absy) + return half(detail::binary, detail::check_underflow(absx)); + if (absy > absx) + std::swap(absx, absy); + for (; absx < 0x400; absx <<= 1, --expx) + ; + for (; absy < 0x400; absy <<= 1, --expy) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = static_cast(mx >> 21), iy = static_cast(my >> 21); + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d < 30) ? ((my >> d) | + ((my & ((static_cast(1) << d) - 1)) != 0)) + : 1; + return half(detail::binary, + detail::hypot_post(mx + my, expx)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). \param x +/// first argument \param y second argument \param z third argument \return +/// square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of +/// the final square root +inline half hypot(half x, half y, half z) { +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); + return half(detail::binary, detail::float2half( + std::sqrt(fx * fx + fy * fy + fz * fz))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, + expx = 0, expy = 0, expz = 0; + if (!absx) + return hypot(y, z); + if (!absy) + return hypot(x, z); + if (!absz) + return hypot(x, y); + if (absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) + ? detail::select(0x7C00, detail::select(y.data_, z.data_)) + : (absy == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, z.data_)) + : (absz == 0x7C00) + ? 
detail::select(0x7C00, detail::select(x.data_, y.data_)) + : detail::signal(x.data_, y.data_, z.data_)); + if (absz > absy) + std::swap(absy, absz); + if (absy > absx) + std::swap(absx, absy); + if (absz > absy) + std::swap(absy, absz); + for (; absx < 0x400; absx <<= 1, --expx) + ; + for (; absy < 0x400; absy <<= 1, --expy) + ; + for (; absz < 0x400; absz <<= 1, --expz) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400, + mz = (absz & 0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = static_cast(mx >> 21), iy = static_cast(my >> 21), + iz = static_cast(mz >> 21); + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + expz = 2 * (expz + (absz >> 10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d < 30) ? ((mz >> d) | + ((mz & ((static_cast(1) << d) - 1)) != 0)) + : 1; + my += mz; + if (my & 0x80000000) { + my = (my >> 1) | (my & 1); + if (++expy > expx) { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d < 30) ? ((my >> d) | + ((my & ((static_cast(1) << d) - 1)) != 0)) + : 1; + return half(detail::binary, + detail::hypot_post(mx + my, expx)); +#endif +} + +/// Power function. +/// This function may be 1 ULP off the correctly rounded exact result for any +/// rounding mode in ~0.00025% of inputs. +/// +/// **See also:** Documentation for +/// [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). \param x +/// base \param y exponent \return \a x raised to \a y \exception FE_INVALID if +/// \a x or \a y is signaling NaN or if \a x is finite an negative and \a y is +/// finite and not integral \exception FE_DIVBYZERO if \a x is 0 and \a y is +/// negative \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half pow(half x, half y) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::pow(detail::half2float(x.data_), + detail::half2float(y.data_)))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if (!absy || x.data_ == 0x3C00) + return half( + detail::binary, + detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || + (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1))); + unsigned int sign = + x.data_ & (static_cast((absy < 0x6800) && is_int && + ((absy >> (25 - (absy >> 10))) & 1)) + << 15); + if (absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy == 0x7C00) + ? ((absx == 0x3C00) ? 0x3C00 + : (!absx && y.data_ == 0xFC00) + ? detail::pole() + : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00)))) + : (sign | (0x7C00 & ((y.data_ >> 15) - 1U)))); + if (!absx) + return half(detail::binary, (y.data_ & 0x8000) ? 
detail::pole(sign) : sign); + if ((x.data_ & 0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if (x.data_ == 0xBC00) + return half(detail::binary, sign | 0x3C00); + switch (y.data_) { + case 0x3800: + return sqrt(x); + case 0x3C00: + return half(detail::binary, detail::check_underflow(x.data_)); + case 0x4000: + return x * x; + case 0xBC00: + return half(detail::binary, 0x3C00) / x; + } + for (; absx < 0x400; absx <<= 1, --exp) + ; + detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + ((detail::log2( + static_cast((absx & 0x3FF) | 0x400) + << 20) + + 8) >> + 4)) ^ + msign) - + msign; + for (exp = -11; m < 0x80000000; m <<= 1, --exp) + ; + for (; absy < 0x400; absy <<= 1, --exp) + ; + m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) + << 21); + int i = static_cast(m >> 31); + exp += (absy >> 10) + i; + m <<= 1 - i; + if (exp < 0) { + f = m >> -exp; + exp = 0; + } else { + f = (m << exp) & 0x7FFFFFFF; + exp = static_cast(m >> (31 - exp)); + } + return half(detail::binary, + detail::exp2_post( + f, exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); +#endif +} + +/// \} +/// \anchor trigonometric +/// \name Trigonometric functions +/// \{ + +/// Compute sine and cosine simultaneously. +/// This returns the same results as sin() and cos() but is faster than +/// calling each function individually. +/// +/// This function is exact to rounding for all rounding modes. +/// \param arg function argument +/// \param sin variable to take sine of \a arg +/// \param cos variable to take cosine of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline void sincos(half arg, half *sin, half *cos) { +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = + half(detail::binary, detail::float2half(std::sin(f))); + *cos = + half(detail::binary, detail::float2half(std::cos(f))); +#else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if (abs >= 0x7C00) + *sin = *cos = + half(detail::binary, + (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); + else if (!abs) { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } else if (abs < 0x2500) { + *sin = half(detail::binary, + detail::rounded(arg.data_ - 1, 1, 1)); + *cos = half(detail::binary, + detail::rounded(0x3BFF, 1, 1)); + } else { + if (half::round_style != std::round_to_nearest) { + switch (abs) { + case 0x48B7: + *sin = half(detail::binary, detail::rounded( + (~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + *cos = half(detail::binary, + detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + *cos = half(detail::binary, + detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half(detail::binary, detail::rounded( + (~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + *cos = half(detail::binary, + detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + *cos = half(detail::binary, + detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + switch (k & 3) { + case 1: + sc = std::make_pair(sc.second, -sc.first); + break; + case 2: + sc = std::make_pair(-sc.first, -sc.second); + break; + case 3: + sc = std::make_pair(-sc.second, sc.first); + break; + } + *sin = half(detail::binary, + detail::fixed2half( + (sc.first ^ -static_cast(sign)) + sign)); + *cos = half( + detail::binary, + detail::fixed2half(sc.second)); + } +#endif +} + +/// Sine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). \param arg +/// function argument \return sine value of \a arg \exception FE_INVALID for +/// signaling NaN or infinity \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding +inline half sin(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sin(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if (!abs) + return arg; + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? detail::invalid() + : detail::signal(arg.data_)); + if (abs < 0x2900) + return half(detail::binary, + detail::rounded(arg.data_ - 1, 1, 1)); + if (half::round_style != std::round_to_nearest) + switch (abs) { + case 0x48B7: + return half(detail::binary, detail::rounded( + (~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + case 0x6A64: + return half(detail::binary, detail::rounded( + (~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + case 0x6D8C: + return half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = + -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); +#endif +} + +/// Cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
\param arg +/// function argument \return cosine value of \a arg \exception FE_INVALID for +/// signaling NaN or infinity \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding +inline half cos(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cos(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if (!abs) + return half(detail::binary, 0x3C00); + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? detail::invalid() + : detail::signal(arg.data_)); + if (abs < 0x2500) + return half(detail::binary, + detail::rounded(0x3BFF, 1, 1)); + if (half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, + detail::rounded(0x80FC, 1, 1)); + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); +#endif +} + +/// Tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). \param arg +/// function argument \return tangent value of \a arg \exception FE_INVALID for +/// signaling NaN or infinity \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding +inline half tan(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tan(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if (!abs) + return arg; + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? detail::invalid() + : detail::signal(arg.data_)); + if (abs < 0x2700) + return half(detail::binary, + detail::rounded(arg.data_, 0, 1)); + if (half::round_style != std::round_to_nearest) + switch (abs) { + case 0x658C: + return half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x07E6, 1, 1)); + case 0x7330: + return half(detail::binary, detail::rounded( + (~arg.data_ & 0x8000) | 0x4B62, 1, 1)); + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 30); + if (k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), + signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first ^ signy) - signy, + mx = (sc.second ^ signx) - signx; + for (; my < 0x80000000; my <<= 1, --exp) + ; + for (; mx < 0x80000000; mx <<= 1, ++exp) + ; + return half(detail::binary, + detail::tangent_post( + my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); +#endif +} + +/// Arc sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). \param arg +/// function argument \return arc sine value of \a arg \exception FE_INVALID for +/// signaling NaN or if abs(\a arg) > 1 \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half asin(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::asin( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if (!abs) + return arg; + if (abs >= 0x3C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? 
detail::invalid() + : detail::rounded( + sign | 0x3E48, 0, 1)); + if (abs < 0x2900) + return half(detail::binary, + detail::rounded(arg.data_, 0, 1)); + if (half::round_style != std::round_to_nearest && + (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, + detail::rounded(arg.data_ + 1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = + detail::atan2(sc.first, sc.second, + (half::round_style == std::round_to_nearest) ? 27 : 26); + return half(detail::binary, + detail::fixed2half( + m, 14, sign)); +#endif +} + +/// Arc cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). \param arg +/// function argument \return arc cosine value of \a arg \exception FE_INVALID +/// for signaling NaN or if abs(\a arg) > 1 \exception FE_OVERFLOW, +/// ...UNDERFLOW, ...INEXACT according to rounding +inline half acos(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::acos( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if (!abs) + return half(detail::binary, + detail::rounded(0x3E48, 0, 1)); + if (abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) + : (abs > 0x3C00) ? detail::invalid() + : sign ? detail::rounded(0x4248, 0, 1) + : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, + detail::fixed2half( + sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); +#endif +} + +/// Arc tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). \param arg +/// function argument \return arc tangent value of \a arg \exception FE_INVALID +/// for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according +/// to rounding +inline half atan(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::atan( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if (!abs) + return arg; + if (abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) + ? detail::rounded( + sign | 0x3E48, 0, 1) + : detail::signal(arg.data_)); + if (abs <= 0x2700) + return half(detail::binary, + detail::rounded(arg.data_ - 1, 1, 1)); + int exp = (abs >> 10) + (abs <= 0x3FF); + detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + detail::uint32 m = + (exp > 15) + ? detail::atan2(my << 19, 0x20000000 >> (exp - 15), + (half::round_style == std::round_to_nearest) ? 26 + : 24) + : detail::atan2(my << (exp + 4), 0x20000000, + (half::round_style == std::round_to_nearest) ? 30 + : 28); + return half(detail::binary, + detail::fixed2half( + m, 14, sign)); +#endif +} + +/// Arc tangent function. +/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% +/// of inputs for `std::round_to_nearest`, in ~0.1% of inputs for +/// `std::round_toward_zero` and in ~0.02% of inputs for any other rounding +/// mode. +/// +/// **See also:** Documentation for +/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
\param y +/// numerator \param x denominator \return arc tangent value \exception +/// FE_INVALID if \a x or \a y is signaling NaN \exception FE_OVERFLOW, +/// ...UNDERFLOW, ...INEXACT according to rounding +inline half atan2(half y, half x) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan2(detail::half2float(y.data_), + detail::half2float(x.data_)))); +#else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, + signx = x.data_ >> 15, signy = y.data_ & 0x8000; + if (absx >= 0x7C00 || absy >= 0x7C00) { + if (absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if (absy == 0x7C00) + return half( + detail::binary, + (absx < 0x7C00) + ? detail::rounded(signy | 0x3E48, 0, 1) + : signx + ? detail::rounded(signy | 0x40B6, 0, 1) + : detail::rounded(signy | 0x3A48, 0, 1)); + return (x.data_ == 0x7C00) + ? half(detail::binary, signy) + : half(detail::binary, detail::rounded( + signy | 0x4248, 0, 1)); + } + if (!absy) + return signx + ? half(detail::binary, detail::rounded( + signy | 0x4248, 0, 1)) + : y; + if (!absx) + return half(detail::binary, + detail::rounded(signy | 0x3E48, 0, 1)); + int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); + if (d > (signx ? 18 : 12)) + return half(detail::binary, + detail::rounded(signy | 0x3E48, 0, 1)); + if (signx && d < -11) + return half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)); + if (!signx && + d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) { + for (; absy < 0x400; absy <<= 1, --d) + ; + detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, + my = ((absy << 1) & 0x7FF) | 0x800; + int i = my < mx; + d -= i; + if (d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, + detail::fixed2half( + my / mx, d + 14, signy, my % mx != 0)); + } + detail::uint32 m = detail::atan2( + ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d + : (d > 0) ? 0 + : -1)), + ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d + : (d < 0) ? 0 + : 1))); + return half(detail::binary, + detail::fixed2half( + signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); +#endif +} + +/// \} +/// \anchor hyperbolic +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). \param arg +/// function argument \return hyperbolic sine value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half sinh(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::sinh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if (!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) + : arg; + if (abs <= 0x2900) + return half(detail::binary, + detail::rounded(arg.data_, 0, 1)); + std::pair mm = detail::hyperbolic_args( + abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for (exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) + ; + unsigned int sign = arg.data_ & 0x8000; + if (exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, + detail::fixed2half( + m, exp, sign)); +#endif +} + +/// Hyperbolic cosine. 
+/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). \param arg +/// function argument \return hyperbolic cosine value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half cosh(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::cosh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if (!abs) + return half(detail::binary, 0x3C00); + if (abs >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = detail::hyperbolic_args( + abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second; + int i = static_cast((~m & 0xFFFFFFFF) >> 31); + m = (m >> i) | (m & i) | 0x80000000; + if ((exp += 13 + i) > 29) + return half(detail::binary, detail::overflow()); + return half( + detail::binary, + detail::fixed2half(m, exp)); +#endif +} + +/// Hyperbolic tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). \param arg +/// function argument \return hyperbolic tangent value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half tanh(half arg) { +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half(std::tanh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if (!abs) + return arg; + if (abs >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ - 0x4000)); + if (abs >= 0x4500) + return half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + if (abs < 0x2700) + return half(detail::binary, + detail::rounded(arg.data_ - 1, 1, 1)); + if (half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, + detail::rounded(arg.data_ - 3, 0, 1)); + std::pair mm = + detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - + (half::round_style != std::round_to_nearest), + mx = mm.first + mm.second; + int i = static_cast((~mx & 0xFFFFFFFF) >> 31); + for (exp = 13; my < 0x80000000; my <<= 1, --exp) + ; + mx = (mx >> i) | 0x80000000; + return half(detail::binary, detail::tangent_post( + my, mx, exp - i, arg.data_ & 0x8000)); +#endif +} + +/// Hyperbolic area sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). \param +/// arg function argument \return area sine value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half asinh(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::asinh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if (!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) + : arg; + if (abs <= 0x2900) + return half(detail::binary, + detail::rounded(arg.data_ - 1, 1, 1)); + if (half::round_style != std::round_to_nearest) + switch (abs) { + case 0x32D4: + return half(detail::binary, detail::rounded( + arg.data_ - 13, 1, 1)); + case 0x3B5B: + return half(detail::binary, detail::rounded( + arg.data_ - 197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). \param +/// arg function argument \return area cosine value of \a arg \exception +/// FE_INVALID for signaling NaN or arguments <1 \exception FE_OVERFLOW, +/// ...UNDERFLOW, ...INEXACT according to rounding +inline half acosh(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::acosh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if ((arg.data_ & 0x8000) || abs < 0x3C00) + return half(detail::binary, (abs <= 0x7C00) ? detail::invalid() + : detail::signal(arg.data_)); + if (abs == 0x3C00) + return half(detail::binary, 0); + if (arg.data_ >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) + : arg; + return half(detail::binary, + detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). \param +/// arg function argument \return area tangent value of \a arg \exception +/// FE_INVALID for signaling NaN or if abs(\a arg) > 1 \exception FE_DIVBYZERO +/// for +/-1 \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to +/// rounding +inline half atanh(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::atanh( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 0; + if (!abs) + return arg; + if (abs >= 0x3C00) + return half(detail::binary, (abs == 0x3C00) + ? detail::pole(arg.data_ & 0x8000) + : (abs <= 0x7C00) ? detail::invalid() + : detail::signal(arg.data_)); + if (abs < 0x2700) + return half(detail::binary, + detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs & 0x3FF) | + ((abs > 0x3FF) << 10)) + << ((abs >> 10) + (abs <= 0x3FF) + 6), + my = 0x80000000 + m, mx = 0x80000000 - m; + for (; mx < 0x80000000; mx <<= 1, ++exp) + ; + int i = my >= mx, s; + return half( + detail::binary, + detail::log2_post( + detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, + exp + i - 1, 16, arg.data_ & 0x8000)); +#endif +} + +/// \} +/// \anchor special +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// This function may be 1 ULP off the correctly rounded exact result for any +/// rounding mode in <0.5% of inputs. +/// +/// **See also:** Documentation for +/// [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). 
\param arg +/// function argument \return error function value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half erf(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::erf(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if (!abs || abs >= 0x7C00) + return (abs >= 0x7C00) ? half(detail::binary, + (abs == 0x7C00) ? (arg.data_ - 0x4000) + : detail::signal(arg.data_)) + : arg; + if (abs >= 0x4200) + return half(detail::binary, detail::rounded( + (arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Complementary error function. +/// This function may be 1 ULP off the correctly rounded exact result for any +/// rounding mode in <0.5% of inputs. +/// +/// **See also:** Documentation for +/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). \param arg +/// function argument \return 1 minus error function value of \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +inline half erfc(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::erfc( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if (abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, + (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) + : arg; + if (!abs) + return half(detail::binary, 0x3C00); + if (abs >= 0x4400) + return half(detail::binary, detail::rounded( + (sign >> 1) - (sign >> 15), sign >> 15, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Natural logarithm of gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any +/// rounding mode in ~0.025% of inputs. +/// +/// **See also:** Documentation for +/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). \param +/// arg function argument \return natural logarith of gamma function for \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 or negative integer arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half lgamma(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::lgamma( + detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if (abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if (!abs || arg.data_ >= 0xE400 || + (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::pole()); + if (arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, + detail::gamma(arg.data_)); +#endif +} + +/// Gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any +/// rounding mode in <0.25% of inputs. +/// +/// **See also:** Documentation for +/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). 
\param +/// arg function argument \return gamma function value of \a arg \exception +/// FE_INVALID for signaling NaN, negative infinity or negative integer +/// arguments \exception FE_DIVBYZERO for 0 \exception FE_OVERFLOW, +/// ...UNDERFLOW, ...INEXACT according to rounding +inline half tgamma(half arg) { +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half(std::tgamma( + detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if (!abs) + return half(detail::binary, detail::pole(arg.data_)); + if (abs >= 0x7C00) + return (arg.data_ == 0x7C00) + ? arg + : half(detail::binary, detail::signal(arg.data_)); + if (arg.data_ >= 0xE400 || + (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::invalid()); + if (arg.data_ >= 0xCA80) + return half(detail::binary, + detail::underflow( + (1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); + if (arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if (arg.data_ == 0x3C00) + return arg; + return half(detail::binary, + detail::gamma(arg.data_)); +#endif +} + +/// \} +/// \anchor rounding +/// \name Rounding +/// \{ + +/// Nearest integer not less than half value. +/// **See also:** Documentation for +/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). \param arg +/// half to round \return nearest integer not less than \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_INEXACT if value had to be +/// rounded +inline half ceil(half arg) { + return half( + detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater than half value. +/// **See also:** Documentation for +/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). \param +/// arg half to round \return nearest integer not greater than \a arg \exception +/// FE_INVALID for signaling NaN \exception FE_INEXACT if value had to be +/// rounded +inline half floor(half arg) { + return half( + detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater in magnitude than half value. +/// **See also:** Documentation for +/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). \param +/// arg half to round \return nearest integer not greater in magnitude than \a +/// arg \exception FE_INVALID for signaling NaN \exception FE_INEXACT if value +/// had to be rounded +inline half trunc(half arg) { + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). \param +/// arg half to round \return nearest integer, rounded away from zero in +/// half-way cases \exception FE_INVALID for signaling NaN \exception FE_INEXACT +/// if value had to be rounded +inline half round(half arg) { + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). \param +/// arg half to round \return nearest integer, rounded away from zero in +/// half-way cases \exception FE_INVALID if value is not representable as `long` +inline long lround(half arg) { + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. 
+/// **See also:** Documentation for +/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). \param arg +/// half expression to round \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half rint(half arg) { + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). \param +/// arg half expression to round \return nearest integer using default rounding +/// mode \exception FE_INVALID if value is not representable as `long` +/// \exception FE_INEXACT if value had to be rounded +inline long lrint(half arg) { + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +inline half nearbyint(half arg) { + return half(detail::binary, + detail::integral(arg.data_)); +} +#if HALF_ENABLE_CPP11_LONG_LONG +/// Nearest integer. +/// **See also:** Documentation for +/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). \param +/// arg half to round \return nearest integer, rounded away from zero in +/// half-way cases \exception FE_INVALID if value is not representable as `long +/// long` +inline long long llround(half arg) { + return detail::half2int( + arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). \param +/// arg half expression to round \return nearest integer using default rounding +/// mode \exception FE_INVALID if value is not representable as `long long` +/// \exception FE_INEXACT if value had to be rounded +inline long long llrint(half arg) { + return detail::half2int(arg.data_); +} +#endif + +/// \} +/// \anchor float +/// \name Floating point manipulation +/// \{ + +/// Decompress floating-point number. +/// **See also:** Documentation for +/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). \param +/// arg number to decompress \param exp address to store exponent at \return +/// significant in range [0.5, 1) \exception FE_INVALID for signaling NaN +inline half frexp(half arg, int *exp) { + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if (abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) + : arg; + for (; abs < 0x400; abs <<= 1, --*exp) + ; + *exp += (abs >> 10) - 14; + return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multiplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbln(half arg, long exp) { + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if (abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) + : arg; + for (; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + if (exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if (exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if (exp > 0) + return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); + unsigned int m = (abs & 0x3FF) | 0x400; + return half(detail::binary, detail::rounded( + sign | (m >> (1 - exp)), (m >> -exp) & 1, + (m & ((1 << -exp) - 1)) != 0)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). \param +/// arg number to modify \param exp power of two to multiply with \return \a arg +/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). \param +/// arg number to modify \param exp power of two to multiply with \return \a arg +/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + +/// Extract integer and fractional parts. +/// **See also:** Documentation for +/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). \param arg +/// number to decompress \param iptr address to store integer part at \return +/// fractional part \exception FE_INVALID for signaling NaN +inline half modf(half arg, half *iptr) { + unsigned int abs = arg.data_ & 0x7FFF; + if (abs > 0x7C00) { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if (abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); + if (abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, + m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if (!m) + return half(detail::binary, arg.data_ & 0x8000); + for (; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). \param +/// arg number to query \return floating-point exponent \retval FP_ILOGB0 for +/// zero \retval FP_ILOGBNAN for NaN \retval INT_MAX for infinity \exception +/// FE_INVALID for 0 or infinite values +inline int ilogb(half arg) { + int abs = arg.data_ & 0x7FFF, exp; + if (!abs || abs >= 0x7C00) { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? INT_MAX : FP_ILOGBNAN; + } + for (exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + return exp; +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). 
\param arg +/// number to query \return floating-point exponent \exception FE_INVALID for +/// signaling NaN \exception FE_DIVBYZERO for 0 +inline half logb(half arg) { + int abs = arg.data_ & 0x7FFF, exp; + if (!abs) + return half(detail::binary, detail::pole(0x8000)); + if (abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for (exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + unsigned int value = static_cast(exp < 0) << 15; + if (exp) { + unsigned int m = std::abs(exp) << 6; + for (exp = 18; m < 0x400; m <<= 1, --exp) + ; + value |= (exp << 10) + m; + } + return half(detail::binary, value); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nextafter(half from, half to) { + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if (fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if (from.data_ == to.data_ || !(fabs | tabs)) + return to; + if (!fabs) { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_ & 0x8000) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ + static_cast( + (from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < + (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) + << 1) - + 1; + detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && + (out & 0x7C00) < 0x400); + return half(detail::binary, out); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nexttoward(half from, long double to) { + int fabs = from.data_ & 0x7FFF; + if (fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if (detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if (!fabs) { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, + (static_cast(detail::builtin_signbit(to)) << 15) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; + detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && + (out & 0x7FFF) < 0x400); + return half(detail::binary, out); +} + +/// Take sign. +/// **See also:** Documentation for +/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). 
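+/// Illustrative usage (example added for clarity; not part of the original
+/// documentation):
+/// \code
+/// half m = copysign(half(2.5f), half(-1.0f)); // m == -2.5
+/// \endcode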
+/// \param x value to change sign for +/// \param y value to take sign from +/// \return value equal to \a x in magnitude and to \a y in sign +inline HALF_CONSTEXPR half copysign(half x, half y) { + return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); +} + +/// \} +/// \anchor classification +/// \name Floating point classification +/// \{ + +/// Classify floating-point value. +/// **See also:** Documentation for +/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). +/// \param arg number to classify +/// \retval FP_ZERO for positive and negative zero +/// \retval FP_SUBNORMAL for subnormal numbers +/// \retval FP_INFINITY for positive and negative infinity +/// \retval FP_NAN for NaNs +/// \retval FP_NORMAL for all other (normal) values +inline HALF_CONSTEXPR int fpclassify(half arg) { + return !(arg.data_ & 0x7FFF) ? FP_ZERO + : ((arg.data_ & 0x7FFF) < 0x400) ? FP_SUBNORMAL + : ((arg.data_ & 0x7FFF) < 0x7C00) ? FP_NORMAL + : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE + : FP_NAN; +} + +/// Check if finite number. +/// **See also:** Documentation for +/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). +/// \param arg number to check +/// \retval true if neither infinity nor NaN +/// \retval false else +inline HALF_CONSTEXPR bool isfinite(half arg) { + return (arg.data_ & 0x7C00) != 0x7C00; +} + +/// Check for infinity. +/// **See also:** Documentation for +/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). \param +/// arg number to check \retval true for positive or negative infinity \retval +/// false else +inline HALF_CONSTEXPR bool isinf(half arg) { + return (arg.data_ & 0x7FFF) == 0x7C00; +} + +/// Check for NaN. +/// **See also:** Documentation for +/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). \param +/// arg number to check \retval true for NaNs \retval false else +inline HALF_CONSTEXPR bool isnan(half arg) { + return (arg.data_ & 0x7FFF) > 0x7C00; +} + +/// Check if normal number. +/// **See also:** Documentation for +/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). +/// \param arg number to check +/// \retval true if normal number +/// \retval false if either subnormal, zero, infinity or NaN +inline HALF_CONSTEXPR bool isnormal(half arg) { + return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); +} + +/// Check sign. +/// **See also:** Documentation for +/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). +/// \param arg number to check +/// \retval true for negative number +/// \retval false for positive number +inline HALF_CONSTEXPR bool signbit(half arg) { + return (arg.data_ & 0x8000) != 0; +} + +/// \} +/// \anchor compfunc +/// \name Comparison +/// \{ + +/// Quiet comparison for greater than. +/// **See also:** Documentation for +/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreater(half x, half y) { + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for greater equal. +/// **See also:** Documentation for +/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). 
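+/// Illustrative usage (example added for clarity; not part of the original
+/// documentation):
+/// \code
+/// half nan = std::numeric_limits<half>::quiet_NaN();
+/// bool ge = isgreaterequal(half(1.0f), nan); // false, and no FE_INVALID is raised
+/// \endcode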
+/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) { + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less than. +/// **See also:** Documentation for +/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). \param +/// x first operand \param y second operand \retval true if \a x less than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isless(half x, half y) { + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less equal. +/// **See also:** Documentation for +/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool islessequal(half x, half y) { + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less or greater. +/// **See also:** Documentation for +/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if either less or greater +/// \retval false else +inline HALF_CONSTEXPR bool islessgreater(half x, half y) { + return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && + !isnan(y); +} + +/// Quiet check if unordered. +/// **See also:** Documentation for +/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). +/// \param x first operand +/// \param y second operand +/// \retval true if unordered (one or two NaN operands) +/// \retval false else +inline HALF_CONSTEXPR bool isunordered(half x, half y) { + return isnan(x) || isnan(y); +} + +/// \} +/// \anchor casting +/// \name Casting +/// \{ + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic +/// type. The values are converted directly using the default rounding mode, +/// without any roundtrip over `float` that a `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref +/// half_float::half) or with any of the two types not being a built-in +/// arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler error and casting between [half](\ref +/// half_float::half)s returns the argument unmodified. \tparam T destination +/// type (half or built-in arithmetic type) \tparam U source type (half or +/// built-in arithmetic type) \param arg value to cast \return \a arg converted +/// to destination type \exception FE_INVALID if \a T is integer type and result +/// is not representable as \a T \exception FE_OVERFLOW, ...UNDERFLOW, +/// ...INEXACT according to rounding +template T half_cast(U arg) { + return detail::half_caster::cast(arg); +} + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic +/// type. 
The values are converted directly using the specified rounding mode, +/// without any roundtrip over `float` that a `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref +/// half_float::half) or with any of the two types not being a built-in +/// arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler error and casting between [half](\ref +/// half_float::half)s returns the argument unmodified. \tparam T destination +/// type (half or built-in arithmetic type) \tparam R rounding mode to use. +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not +/// representable as \a T \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT +/// according to rounding +template T half_cast(U arg) { + return detail::half_caster::cast(arg); +} +/// \} + +/// \} +/// \anchor errors +/// \name Error handling +/// \{ + +/// Clear exception flags. +/// This function works even if [automatic exception flag handling](\ref +/// HALF_ERRHANDLING_FLAGS) is disabled, but in that case manual flag management +/// is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). +/// \param excepts OR of exceptions to clear +/// \retval 0 all selected flags cleared successfully +inline int feclearexcept(int excepts) { + detail::errflags() &= ~excepts; + return 0; +} + +/// Test exception flags. +/// This function works even if [automatic exception flag handling](\ref +/// HALF_ERRHANDLING_FLAGS) is disabled, but in that case manual flag management +/// is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). +/// \param excepts OR of exceptions to test +/// \return OR of selected exceptions if raised +inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + +/// Raise exception flags. +/// This raises the specified floating point exceptions and also invokes any +/// additional automatic exception handling as configured with the +/// [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. +/// This function works even if [automatic exception flag handling](\ref +/// HALF_ERRHANDLING_FLAGS) is disabled, but in that case manual flag management +/// is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). +/// \param excepts OR of exceptions to raise +/// \retval 0 all selected exceptions raised successfully +inline int feraiseexcept(int excepts) { + detail::errflags() |= excepts; + detail::raise(excepts); + return 0; +} + +/// Save exception flags. +/// This function works even if [automatic exception flag handling](\ref +/// HALF_ERRHANDLING_FLAGS) is disabled, but in that case manual flag management +/// is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp address to store flag state at +/// \param excepts OR of flags to save +/// \retval 0 for success +inline int fegetexceptflag(int *flagp, int excepts) { + *flagp = detail::errflags() & excepts; + return 0; +} + +/// Restore exception flags. 
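+/// Illustrative usage (example added for clarity; not part of the original
+/// documentation):
+/// \code
+/// int saved;
+/// fegetexceptflag(&saved, FE_ALL_EXCEPT); // snapshot the current flags
+/// // ... half-precision computations that may raise flags ...
+/// fesetexceptflag(&saved, FE_ALL_EXCEPT); // restore the snapshot
+/// \endcode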
+/// This only copies the specified exception state (including unset flags) +/// without incurring any additional exception handling. This function works +/// even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, but in that case manual flag management is the only way to raise +/// flags. +/// +/// **See also:** Documentation for +/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp address to take flag state from +/// \param excepts OR of flags to restore +/// \retval 0 for success +inline int fesetexceptflag(const int *flagp, int excepts) { + detail::errflags() = + (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); + return 0; +} + +/// Throw C++ exceptions based on set exception flags. +/// This function manually throws a corresponding C++ exception if one of the +/// specified flags is set, no matter if automatic throwing (via +/// [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID)) is +/// enabled or not. This function works even if [automatic exception flag +/// handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, but in that case manual +/// flag management is the only way to raise flags. \param excepts OR of +/// exceptions to test \param msg error message to use for exception description +/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and +/// set \throw std::overflow_error if `FE_OVERFLOW` is selected and set \throw +/// std::underflow_error if `FE_UNDERFLOW` is selected and set \throw +/// std::range_error if `FE_INEXACT` is selected and set +inline void fethrowexcept(int excepts, const char *msg = "") { + excepts &= detail::errflags(); + if (excepts & (FE_INVALID | FE_DIVBYZERO)) + throw std::domain_error(msg); + if (excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if (excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if (excepts & FE_INEXACT) + throw std::range_error(msg); +} +/// \} +} // namespace half_float + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/tilelang/original/src/tl_templates/cpu/common.h b/tilelang/original/src/tl_templates/cpu/common.h new file mode 100644 index 0000000000000000000000000000000000000000..b288cd114bb254791f6aaf8396616002b405ebaa --- /dev/null +++ b/tilelang/original/src/tl_templates/cpu/common.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include + +// Not Implemented +F \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/cpu/gemm.h b/tilelang/original/src/tl_templates/cpu/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..1d8fbb7e2a9d460cef00585313abd89368858824 --- /dev/null +++ b/tilelang/original/src/tl_templates/cpu/gemm.h @@ -0,0 +1,3 @@ +#pragma once + +// Not Implemented diff --git a/tilelang/original/src/tl_templates/cuda/atomic.h b/tilelang/original/src/tl_templates/cuda/atomic.h new file mode 100644 index 0000000000000000000000000000000000000000..f6096cc9d5d280b9e3e031e1149d4e2098966766 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/atomic.h @@ -0,0 +1,711 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include +#include +#include + +using cutlass::bfloat16_t; +using cutlass::half_t; + +#define TL_DEVICE __forceinline__ 
__device__ +#define TL_NOT_IMPLEMENTED() \ + { \ + printf("%s not implemented\n", __PRETTY_FUNCTION__); \ + asm volatile("brkpt;\n"); \ + } +template struct normalize_atomic_type { + using type = T; +}; + +template <> struct normalize_atomic_type { + using type = half; +}; + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +template <> struct normalize_atomic_type { + using type = __nv_bfloat16; +}; +#endif + +template TL_DEVICE T1 cuda_cast(T2 val) { + return T1(val); +} + +template <> TL_DEVICE half cuda_cast(float val) { + return __float2half(val); +} + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { + return __float2bfloat16(val); +} +#endif + +template +TL_DEVICE void AtomicMax(T1 *ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = ref; + if constexpr (std::is_same_v || + std::is_same_v) { + // There is no implementation of atomicMax for half and bf16 in cuda. + // We simulate this process by atomicCAS loop. + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val > *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + } else { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*address); + aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif + } +} + +template +TL_DEVICE T1 AtomicMaxRet(T1 *ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = ref; + if constexpr (std::is_same_v || + std::is_same_v) { + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val > *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + return static_cast(*reinterpret_cast(&old_val_ushort)); + } else { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif + } +} + +template +TL_DEVICE void AtomicMin(T1 *ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = ref; + if constexpr (std::is_same_v || + std::is_same_v) { + // There is no implementation of atomicMin for half and bf16 in cuda. + // We simulate this process by atomicCAS loop. 
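+    // (Added note) The loop below keeps the most recent value returned by
+    // atomicCAS in old_val_ushort and only retries while `val` is still smaller
+    // than that observed value, so it exits as soon as the CAS succeeds or
+    // another thread has already stored a smaller value.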
+ unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val < *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + } else { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*address); + aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif + } +} + +template +TL_DEVICE T1 AtomicMinRet(T1 *ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = ref; + if constexpr (std::is_same_v || + std::is_same_v) { + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val < *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + return static_cast(*reinterpret_cast(&old_val_ushort)); + } else { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif + } +} + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 890)) +template +TL_DEVICE void AtomicAdd(T1 *address, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + if constexpr (std::is_same_v || + std::is_same_v) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(address), static_cast(val)); + } else { + // Since atomic ref do not support memory order, we need to inline ptx + // code here for each situation + if constexpr (std::is_same_v) { + // fp16 + __half ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + } else if constexpr (std::is_same_v) { + // bf16 + __nv_bfloat16 ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : 
"memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + } + } + } else { + atomicAdd(reinterpret_cast(address), cuda_cast(val)); + } +} +#else +template +TL_DEVICE void AtomicAdd(T1 *address, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + (void)memory_order; + atomicAdd(reinterpret_cast(address), cuda_cast(val)); +} +#endif + +template +TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + if constexpr (std::is_same_v || + std::is_same_v) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return static_cast( + atomicAdd(reinterpret_cast(address), static_cast(val))); + } else { + if constexpr (std::is_same_v) { + // fp16 + __half ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + return static_cast(*reinterpret_cast<__half *>(&ret_val_cast)); + } else if constexpr (std::is_same_v) { + // bf16 + __nv_bfloat16 ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + return static_cast( + *reinterpret_cast<__nv_bfloat16 *>(&ret_val_cast)); + } + } + } else { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif + } +} + +// TODO add memory_order for vectorized 
atomic add +TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here + // Note: Vectorized atomic operations only support global space + // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned + // short and use "h" register in assembly + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } +} + +TL_DEVICE half2 +AtomicAddx2Ret(half_t *ref, half_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return 
half2(*reinterpret_cast<__half *>(&ret_val_x_cast), + *reinterpret_cast<__half *>(&ret_val_y_cast)); + } +} + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } +} + +TL_DEVICE __nv_bfloat162 +AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), + 
*reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); + } +} +#endif + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +TL_DEVICE void AtomicAddx2(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + } +} + +TL_DEVICE float2 +AtomicAddx2Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + return ret_val; + } +} + +TL_DEVICE void AtomicAddx4(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here + // Note: Vectorized atomic operations only support global space + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if 
(memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + } +} + +TL_DEVICE float4 +AtomicAddx4Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.global.gpu.release.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.global.gpu.acquire.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.global.gpu.acq_rel.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + return ret_val; + } +} +#else +TL_DEVICE void AtomicAddx2(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float2 add_val = *reinterpret_cast(val); + atomicAdd(ref + 0, add_val.x); + atomicAdd(ref + 1, add_val.y); +} + +TL_DEVICE float2 +AtomicAddx2Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float2 add_val = *reinterpret_cast(val); + float2 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + return ret; +} + +TL_DEVICE void AtomicAddx4(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float4 add_val = *reinterpret_cast(val); + atomicAdd(ref + 0, add_val.x); + atomicAdd(ref + 1, add_val.y); + atomicAdd(ref + 2, add_val.z); + atomicAdd(ref + 3, add_val.w); +} + +TL_DEVICE float4 +AtomicAddx4Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float4 add_val = *reinterpret_cast(val); + float4 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + ret.z = atomicAdd(ref + 2, add_val.z); + ret.w = atomicAdd(ref + 3, add_val.w); + return ret; +} +#endif + +template TL_DEVICE T AtomicLoad(T *ref, int memory_order) { +#if CUDART_VERSION >= 11080 + cuda::atomic_ref 
aref(*ref); + return aref.load(cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif +} + +template +TL_DEVICE void AtomicStore(T1 *ref, T2 value, int memory_order) { + using NT1 = typename normalize_atomic_type::type; +#if CUDART_VERSION >= 11080 + cuda::atomic_ref aref(*ref); + aref.store(cuda_cast(value), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif +} diff --git a/tilelang/original/src/tl_templates/cuda/barrier.h b/tilelang/original/src/tl_templates/cuda/barrier.h new file mode 100644 index 0000000000000000000000000000000000000000..79a57f7df1b87b581c003bcb685faa29f67b8cef --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/barrier.h @@ -0,0 +1,162 @@ +#pragma once + +#include "common.h" +#include + +// Reuse cutlass advanced barrier abstraction +using Barrier = cutlass::arch::ClusterTransactionBarrier; + +namespace tl { + +TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.init.shared.b64 [%1], %0;" + : + : "r"(arrive_count), "r"(smem_int_ptr)); +} + +TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) { + + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint32_t waitComplete; + + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_int_ptr), "r"(phase_bit)); + + return waitComplete; +} + +TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) { + if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + // Arbitrarily large timer value after which try-wait expires and re-tries. 
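+    // (Added note) 0x989680 is 10,000,000 in decimal; when try_wait times out,
+    // the code below simply branches back to LAB_WAIT, so the wait spins until
+    // the barrier's phase bit flips.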
+ uint32_t ticks = 0x989680; + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks)); + } +} + +TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures + // to save instruction issue slots + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(smem_int_ptr), + "r"(phase_bit)); +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr)); +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, + uint32_t pred) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + if (pred) { + asm volatile("{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_int_ptr), "r"(cta_id)); + } +} + +TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr)); +} + +TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr)); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];" + : + : "r"(smem_int_mbar)); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("{\n\t" + "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_int_mbar)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar); +} + +TL_DEVICE void fence_proxy_async() { + asm volatile("fence.proxy.async.shared::cta;" : :); +} + +TL_DEVICE void fence_barrier_init() { + asm volatile("fence.mbarrier_init.release.cluster;" : :); +} + +// Indicate arrival of warp issuing TMA_STORE +TL_DEVICE void tma_store_arrive() { + asm volatile("cp.async.bulk.commit_group;"); +} + +template TL_DEVICE void tma_store_wait() { + asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory"); +} + +TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint64_t state = 0; + asm volatile("{\n" + ".reg .pred P1;\n" + "mbarrier.arrive.shared.b64 %1, [%0];\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.shared.b64 P1, [%0], %1;\n" + "@!P1 bra.uni 
LAB_WAIT;\n" + "}\n" + : + : "r"(smem_int_ptr), "l"(state)); +} +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/common.h b/tilelang/original/src/tl_templates/cuda/common.h new file mode 100644 index 0000000000000000000000000000000000000000..bf2a5100bd4d390fd7c91a3f4b601b03b6781f22 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/common.h @@ -0,0 +1,628 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "atomic.h" +#include +#include +#include +#include + +#include +#include + +using cutlass::bfloat16_t; +using cutlass::half_t; +using cutlass::tfloat32_t; + +using cute::cast_smem_ptr_to_uint; + +using int4_t = int4; + +#define hexp cutlass::fast_exp +#define hlog cutlass::fast_log +#define hsqrt cutlass::fast_sqrt +#define hsin cutlass::fast_sin +#define hcos cutlass::fast_cos +#define htanh cutlass::fast_tanh +#define hpow powf + +#define uint unsigned int +#define uchar unsigned char +#define ushort unsigned short + +#define TL_DEVICE __forceinline__ __device__ +#define TL_DEVICE_NOINLINE __noinline__ __device__ +#define TL_PATCH + +#define TILELANG_CHECK(stmt) \ + do { \ + cudaError_t __err = (stmt); \ + if (__err != cudaSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \ + __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define TILELANG_CHECK_LAST_ERROR(kernel_name) \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, kernel_name ": %s - %s", \ + cudaGetErrorName(__err), cudaGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +// using cutlass abs function for half_t +TL_PATCH TL_DEVICE half_t __habs(const half_t x) { return abs(x); } + +// using cutlass abs function for bfloat_t +TL_PATCH TL_DEVICE bfloat16_t __habs(const bfloat16_t x) { return abs(x); } + +// hrsqrt function for half_t +TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) { + return half_t(hrsqrt(x.to_half())); +} + +// Pack two half values. +TL_DEVICE unsigned __pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two half_t values. +TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. +TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. +TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack four char values. +TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2, + signed char x3) { + return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0; +} + +// Pack eight char values. +TL_DEVICE int2 make_int2(signed char x0, signed char x1, signed char x2, + signed char x3, signed char y0, signed char y1, + signed char y2, signed char y3) { + int2 result; + result.x = make_int(x0, x1, x2, x3); + result.y = make_int(y0, y1, y2, y3); + return result; +} + +// Pack sixteen char values. 
+TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2, + signed char x3, signed char y0, signed char y1, + signed char y2, signed char y3, signed char z0, + signed char z1, signed char z2, signed char z3, + signed char w0, signed char w1, signed char w2, + signed char w3) { + int4_t result; + result.x = make_int(x0, x1, x2, x3); + result.y = make_int(y0, y1, y2, y3); + result.z = make_int(z0, z1, z2, z3); + result.w = make_int(w0, w1, w2, w3); + return result; +} + +TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0, + short z1, short w0, short w1) { + int4_t result; + *((short2 *)&result.x) = make_short2(x0, x1); + *((short2 *)&result.y) = make_short2(y0, y1); + *((short2 *)&result.z) = make_short2(z0, z1); + *((short2 *)&result.w) = make_short2(w0, w1); + return result; +} + +// Pack eight int values. +TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, + int z1, int w0, int w1) { + longlong4 result; + *((int2 *)&result.x) = make_int2(x0, x1); + *((int2 *)&result.y) = make_int2(y0, y1); + *((int2 *)&result.z) = make_int2(z0, z1); + *((int2 *)&result.w) = make_int2(w0, w1); + return result; +} + +// Helper to cast SMEM pointer to unsigned +TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); +} + +/** + * Convert a shared-memory pointer to a 32-bit unsigned integer address. + * + * Casts the given pointer (expected to reference shared memory) into a 32-bit + * unsigned integer using the device address-space conversion required for + * shared-memory pointers. + * + * @param smem_ptr Pointer into shared memory. + * @return 32-bit unsigned integer representation of the shared-memory address. + * + * @note The pointer must refer to shared memory; behavior is undefined for + * pointers in other address spaces. + */ +TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) { + unsigned int smem_int; + asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; " + "cvt.u32.u64 %0, smem_int; }" + : "=r"(smem_int) + : "l"(smem_ptr)); + return smem_int; +} + +// DP4A +template +TL_DEVICE /** + * Compute a 4×8-bit dot-product-accumulate using the CUDA DP4A + * intrinsic. + * + * Reads 32-bit packed values from `a` and `b` (each containing four + * signed 8-bit lanes), applies the __dp4a operation (dot product of + * the four lane pairs added to an accumulator), and stores the 32-bit + * integer result through `c`. + * + * @param a Pointer to a 32-bit packed input containing four signed + * 8-bit elements. + * @param b Pointer to a 32-bit packed input containing four signed + * 8-bit elements. + * @param c Pointer to a 32-bit accumulator; its current value is used + * as the initial accumulator and overwritten with the resulting int32 + * sum. + */ + void + DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) { + const int a_int = *((int *)a); + const int b_int = *((int *)b); + const int c_int = *((int *)c); + *c = __dp4a(a_int, b_int, c_int); +} + +namespace tl { +/*! + * \brief PTX data type. 
+ * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +union GmmaDescriptor { + CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(uint64_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 + // brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. 
+ uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, + : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { + return desc_; + } + template + CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const { + GmmaDescriptor ret; + ret.reg32_[0] = reg32_[0] + uint32_t(offset); + ret.reg32_[1] = reg32_[1]; + return ret; + } +}; + +union Tcgen05SMemDescriptor { + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor( + Tcgen05SMemDescriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor( + Tcgen05SMemDescriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor & + operator=(Tcgen05SMemDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor & + operator=(Tcgen05SMemDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + uint16_t stride_byte_offset_ : 14, + version_ : 2; // 14 bits [0,14), 2 bits [14,16) + // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53). 
+ uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, + : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused + // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, + // SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, + // SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, + // N/A = 7 + uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8) + } bitfield; + // Separate the field, as we may only update one part of desc + struct { + uint32_t lo; + uint32_t hi; + } words; + + CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { + return desc_; + } + template + CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor + operator+(const T &offset) const { + Tcgen05SMemDescriptor ret; + // Address addition is in units of 16 bytes (4 LSB not encoded) + ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4); + ret.reg32_[1] = reg32_[1]; + return ret; + } +}; + +// +// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout) +// +union Tcgen05InstrDescriptor { + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor( + Tcgen05InstrDescriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor( + Tcgen05InstrDescriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor & + operator=(Tcgen05InstrDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor & + operator=(Tcgen05InstrDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint32_t desc_; + uint16_t reg16_[2]; + + // Bitfield implementation mirrors cute::UMMA::InstrDescriptor + struct { + // bit [ 0, 2) : Sparse meta data id2 + uint16_t sparse_id2_ : 2, + // bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for + // F32F16/S8/MXF8F6F4 + sparse_flag_ : 1, + // bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8 + saturate_ : 1, + // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32 + c_format_ : 2, + // padding + : 1, + // bit [ 7,10) : see UMMA format encoding + a_format_ : 3, + // bit [10,13) : see UMMA format encoding + b_format_ : 3, + // bit [13,14) : 0 = no negate. 1 = negate + a_negate_ : 1, + // bit [14,15) : 0 = no negate. 1 = negate + b_negate_ : 1, + // bit [15,16) : 0 = K-major. 
1 = MN-major + a_major_ : 1; + + // Upper 16 bits + uint16_t b_major_ : 1, // bit [16,17) + n_dim_ : 6, // bit [17,23) : 3 LSBs not included + : 1, // padding + m_dim_ : 5, // bit [24,29) : 4 LSBs not included + : 1, // padding + max_shift_ : 2; // bit [30,32) + } bitfield; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept { + return desc_; + } +}; + +// Any +template TL_DEVICE bool Any(T *a, int size) { + for (int i = 0; i < size; i++) { + if (a[i]) { + return true; + } + } + return false; +} + +// All +template TL_DEVICE bool All(T *a, int size) { + for (int i = 0; i < size; i++) { + if (!a[i]) { + return false; + } + } + return true; +} + +// Pow of int +template TL_DEVICE T pow_of_int(T x) { + T result = x; + for (int i = 1; i < y; i++) { + result *= x; + } + return result; +} + +// Thread partial barrier synchronization +// https://docs.nvidia.com/cuda/parallel-thread-execution/#memory-consistency-model +template +TL_DEVICE void __sync_thread_partial() { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count)); +} + +template +TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor, + T *start_address) { + descriptor.bitfield.start_address_ = + cute::cast_smem_ptr_to_uint(start_address) >> 4; + descriptor.bitfield.layout_type_ = layout_type; + descriptor.bitfield.base_offset_ = 0; + descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; + descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; +} + +template +TL_DEVICE void +initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor, + T *start_address, int leading_byte_offset, + int stride_byte_offset, int base_offset, + bool leading_is_absolute, int swizzle_mode) { + + descriptor.bitfield.start_address_ = + static_cast(cast_smem_ptr_to_uint(start_address) >> 4); + descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; + descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; + descriptor.bitfield.version_ = 1; + descriptor.bitfield.base_offset_ = base_offset & 0x7; + descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0; + descriptor.bitfield.layout_type_ = swizzle_mode & 0x7; +} + +template +TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, + T offset) { + descriptor.reg32_[0] += (offset >> 4); +} + +// and add the desired implicit conversion from bfloat16_t. +struct float_e4m3_t : public cute::float_e4m3_t { + using cute::float_e4m3_t::float_e4m3_t; + CUTLASS_HOST_DEVICE + float_e4m3_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(__nv_bfloat16 x) + : float_e4m3_t(static_cast(x)) {} +}; + +struct float_e5m2_t : public cute::float_e5m2_t { + using cute::float_e5m2_t::float_e5m2_t; + CUTLASS_HOST_DEVICE + float_e5m2_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(__nv_bfloat16 x) + : float_e5m2_t(static_cast(x)) {} +}; + +template struct to_cute_type { + using type = T; +}; +template <> struct to_cute_type { + using type = cute::float_e4m3_t; +}; +template <> struct to_cute_type { + using type = cute::float_e5m2_t; +}; + +} // namespace tl + +namespace cutlass { +TL_DEVICE +bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); } +} // namespace cutlass + +// +// Type-safe warp shuffle helpers for 16-bit float types +// These wrappers avoid relying on implicit conversions that may be disallowed +// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to +// float for the shuffle and then down-converting. 
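+// For example, a warp-wide sum over cutlass::half_t values can use the
+// specializations below directly (sketch, assuming a full 32-lane warp):
+//   for (int offset = 16; offset > 0; offset >>= 1)
+//     v = v + tl::shfl_xor_sync(0xffffffffu, v, offset);
+// The shuffle itself happens in float, so no implicit float -> half_t
+// conversion is needed.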
+// +namespace tl { + +// Generic passthroughs +template +TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) { + return __shfl_xor_sync(mask, val, laneMask); +} + +template +TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) { + return __shfl_down_sync(mask, val, delta); +} + +template +TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) { + return __shfl_up_sync(mask, val, delta); +} + +template TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) { + return __shfl_sync(mask, val, srcLane); +} + +// Specializations for cutlass::half_t +template <> +TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) { + float f = static_cast(val); + float r = __shfl_xor_sync(mask, f, laneMask); + return half_t(r); +} + +template <> +TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) { + float f = static_cast(val); + float r = __shfl_down_sync(mask, f, delta); + return half_t(r); +} + +template <> +TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) { + float f = static_cast(val); + float r = __shfl_up_sync(mask, f, delta); + return half_t(r); +} + +template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) { + float f = static_cast(val); + float r = __shfl_sync(mask, f, srcLane); + return half_t(r); +} + +// Specializations for cutlass::bfloat16_t +template <> +TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val, + int laneMask) { + float f = static_cast(val); + float r = __shfl_xor_sync(mask, f, laneMask); + return bfloat16_t(r); +} + +template <> +TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) { + float f = static_cast(val); + float r = __shfl_down_sync(mask, f, delta); + return bfloat16_t(r); +} + +template <> +TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) { + float f = static_cast(val); + float r = __shfl_up_sync(mask, f, delta); + return bfloat16_t(r); +} + +template <> +TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) { + float f = static_cast(val); + float r = __shfl_sync(mask, f, srcLane); + return bfloat16_t(r); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/compress_sm90.cu b/tilelang/original/src/tl_templates/cuda/compress_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8bb236dd8374c5cc2d41a48b5acc16c503fd7fb4 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/compress_sm90.cu @@ -0,0 +1,167 @@ +#include + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } +template +std::tuple compress_impl(torch::Tensor A) { + using ElementA = T; + using 
ElementE = uint8_t; + using LayoutTagA = conditional_t; + using ProblemShape = cute::Shape; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + // NOTE: this is derived from sparse sm90 mma atoms + // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp + using SparseE = conditional_t<(sizeof_bits_v == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; + static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; + using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, ElementA>, GmmaMajorA, + SparseE, cute::C>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); + TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); + + int M = -1; + int K = -1; + int N = -1; // not used, but required for config + int L = 1; + if constexpr(transposed) { + M = A.size(1); + K = A.size(0); + } else { + M = A.size(0); + K = A.size(1); + } + + ProblemShape problem_shape = make_tuple(M, N, K, L); + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + auto dtype = A.dtype().toScalarType(); + torch::Tensor A_compressed = torch::zeros(KC * M, + torch::TensorOptions().dtype(dtype).device(A.device())); + torch::Tensor E = torch::zeros({ME, KE}, + torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = A.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Compressor::Arguments arguments{problem_shape, + { + A.data_ptr(), + stride_A, + A_compressed.data_ptr(), + E.data_ptr(), + }, + {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + if constexpr (transposed) { + return std::make_tuple(A_compressed.view({KC, M}), E); + } else { + return std::make_tuple(A_compressed.view({M, KC}), E); + } +} + +// block <= 128 +// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 +#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (BLOCK_K) { \ + case int(32 * FACTOR): return compress_impl(TENSOR); \ + case int(64 * FACTOR): return compress_impl(TENSOR); \ + case int(128 * FACTOR): return compress_impl(TENSOR); \ + 
default: \ + TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ + } \ + }() + +#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (dtype) { \ + case torch::kFloat32: \ + return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ + case torch::kFloat16: \ + case torch::kBFloat16: \ + return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ + case torch::kFloat8_e4m3fn: \ + return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ + case torch::kFloat8_e5m2: \ + return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ + case torch::kChar: \ + return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ + case torch::kByte: \ + return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ + default: \ + TORCH_CHECK(false, "Unsupported dtype"); \ + } \ + }() + +std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { + auto dtype = A.dtype().toScalarType(); + return transposed ? DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); +} + +#undef DISPATCH_BLOCK_K +#undef DISPATCH_CONTIGUOUS + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90), + "compress_sm90"); +} diff --git a/tilelang/original/src/tl_templates/cuda/copy.h b/tilelang/original/src/tl_templates/cuda/copy.h new file mode 100644 index 0000000000000000000000000000000000000000..0fa7b9d91e30f3299cdf675c7eaaa55d5b1594c2 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/copy.h @@ -0,0 +1,81 @@ +#pragma once + +#include "common.h" + +#ifdef __CUDA_ARCH_LIST__ +#if __CUDA_ARCH_LIST__ >= 900 +#include "copy_sm90.h" +#endif +#if __CUDA_ARCH_LIST__ >= 1000 +#include "copy_sm100.h" +#endif +#endif + +namespace tl { + +TL_DEVICE void cp_async_commit() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template TL_DEVICE void cp_async_wait() { + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + } +} + +template +TL_DEVICE void cp_async_gs(void const *const smem_addr, + void const *global_ptr) { + static_assert(N == 16 || N == 8 || N == 4); + unsigned int addr = smem_ptr_to_uint(smem_addr); + if constexpr (N == 16) { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" +#else + "cp.async.cg.shared.global [%0], [%1], %2;" +#endif + ::"r"(addr), + "l"((void const *)(global_ptr)), "n"(N)); + } else { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" +#else + "cp.async.ca.shared.global [%0], [%1], %2;" +#endif + ::"r"(addr), + "l"((void const *)(global_ptr)), "n"(N)); + } +} + +template +TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr, + void const *global_ptr, bool cond) { + static_assert(N == 16 || N == 8 || N == 4); + int bytes = cond ? 
N : 0; + unsigned int addr = smem_ptr_to_uint(smem_addr); + if constexpr (N == 16) { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;" +#else + "cp.async.cg.shared.global [%0], [%1], %2, %3;" +#endif + ::"r"(addr), + "l"((void const *)(global_ptr)), "n"(N), "r"(bytes)); + } else { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;" +#else + "cp.async.ca.shared.global [%0], [%1], %2, %3;" +#endif + ::"r"(addr), + "l"((void const *)(global_ptr)), "n"(N), "r"(bytes)); + } +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/copy_sm100.h b/tilelang/original/src/tl_templates/cuda/copy_sm100.h new file mode 100644 index 0000000000000000000000000000000000000000..82d0cca260e225dc6a1128cac1edeac9435d5293 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/copy_sm100.h @@ -0,0 +1,158 @@ +#pragma once +#include "cuda_fp8.h" +#include "tcgen_05.h" +#include "tcgen_05_ld.h" + +namespace tl { + +// 256-bit load for longlong4 +__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) { + longlong4 ret; + asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +// 256-bit load for ulonglong4 +__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +// Generic 256-bit load for FP8 types (returns ulonglong4) +template +__device__ __forceinline__ ulonglong4 ld_global_256(const T *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +// 256-bit store for longlong4 +__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) { + asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +// 256-bit store for ulonglong4 with non-const reference +__device__ __forceinline__ void st_global_256(ulonglong4 *ptr, + ulonglong4 &val) { + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +// 256-bit store for ulonglong4 with const reference +// must be const &val, otherwise the compiler will generate a temporary variable +// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr)) +__device__ __forceinline__ void st_global_256(ulonglong4 *ptr, + const ulonglong4 &val) { + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +// Generic 256-bit store for FP8 types +template +__device__ __forceinline__ void st_global_256(T *ptr, const ulonglong4 &val) { + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +// Generic 256-bit store for FP8 types with non-const reference +template +__device__ __forceinline__ void st_global_256(T *ptr, T &val) { + ulonglong4 &val_u64 = *((ulonglong4 *)&val); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val_u64.x), "l"(val_u64.y), "l"(val_u64.z), + "l"(val_u64.w)); +} + +__device__ __forceinline__ unsigned long long +pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, + 
const bfloat16_t w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +__device__ __forceinline__ unsigned long long +pack_float16x4(const half x, const half y, const half z, const half w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +// Helper function to find the largest K that 2**K <= N +// Requires N > 0 +template +__device__ __forceinline__ constexpr int get_floor_log2() { + static_assert(N > 0); + if constexpr ((1 << (K + 1)) > N) + return K; + else + return get_floor_log2(); +} + +template +__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, + dst_t *dst_ptr) { + static_assert(N > 0); + constexpr int LOG_N = get_floor_log2(); + constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N); + target_call_cls::copy(tmem_start_col, (uint32_t *)dst_ptr); + if constexpr (N - CUR_SEGMENT_LEN > 0) { + tcgen05_ld_core( + tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN); + } +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core, 6, N>( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core, 5, N>( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/copy_sm90.h b/tilelang/original/src/tl_templates/cuda/copy_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..0b51450b3148154de2c65916f25b65fabd091f97 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/copy_sm90.h @@ -0,0 +1,270 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "barrier.h" +#include "common.h" + +namespace tl { +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + +template +TL_DEVICE void tma_load(void *smem_ptr, void const *gmem_ptr, + BarrierType &smem_mbar, uint32_t size) { + uint32_t smem_int_mbar = + smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" + "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), + "l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar) + :); +} + +TL_DEVICE void tma_load_multicast(void *smem_ptr, void 
*gmem_ptr, + uint64_t &smem_mbar, uint32_t size, + uint16_t mask) { + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes." + "multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask) + :); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +} +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE 
void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &coord_c, + int32_t const &coord_w, int32_t const &coord_h, + int32_t const &coord_n, uint16_t const &offset_w, + uint16_t const &offset_h) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar = + smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:" + ":complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group" + ".L2::cache_hint [%0], [%1], %2, %3;" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint) + :); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2}], [%1], %3;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3}], [%1], %4;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4}], [%1], %5;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, 
int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +} + +TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr, + int32_t const &store_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 " + "[%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory"); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh b/tilelang/original/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f5641f61609172090da1c8e77e43f9f4694ccca0 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh @@ -0,0 +1,257 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include + +namespace fastertransformer { + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x);; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; t.x = x; t.y = y; return t; +} + +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); +#else + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); 
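+  // The remaining high half of d is unpacked below; the packed a * b * c + d
+  // is then evaluated element-wise in fp32 and rounded back to a bf16 pair.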
+ fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace fastertransformer diff --git a/tilelang/original/src/tl_templates/cuda/cuda_bf16_wrapper.h b/tilelang/original/src/tl_templates/cuda/cuda_bf16_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..efb6e798730879bc2cd16088b2091991862a6074 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/cuda_bf16_wrapper.h @@ -0,0 +1,23 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/tilelang/original/src/tl_templates/cuda/cuda_fp4.h b/tilelang/original/src/tl_templates/cuda/cuda_fp4.h new file mode 100644 index 0000000000000000000000000000000000000000..e3f56622f432f36bd6f028adf8b0fa5dbdb55c68 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/cuda_fp4.h @@ -0,0 +1,157 @@ +#pragma once + +#include "common.h" + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#include + +// Wrapper for __nv_fp4_e2m1 with implicit conversions +struct fp4_e2_t { + __nv_fp4_storage_t __x; + + TL_DEVICE fp4_e2_t() = default; + + // Constructor from __nv_fp4_e2m1 + TL_DEVICE fp4_e2_t(__nv_fp4_e2m1 x) : __x(x.__x) {} + + // Constructor from storage type + TL_DEVICE fp4_e2_t(__nv_fp4_storage_t x) : __x(x) {} + + // Constructor from float + TL_DEVICE explicit fp4_e2_t(float x) { + __nv_fp4_e2m1 tmp(x); + __x = tmp.__x; + } + + // Conversion to __nv_fp4_e2m1 + TL_DEVICE operator __nv_fp4_e2m1() const { + __nv_fp4_e2m1 tmp; + tmp.__x = __x; + return tmp; + } + + // Conversion to float + TL_DEVICE operator float() const { + __nv_fp4_e2m1 tmp; + tmp.__x = __x; + return float(tmp); + } + + // Implicit conversion to half_t (cutlass::half_t) + TL_DEVICE operator half_t() const { return half_t(float(*this)); } + + // Implicit conversion to __half + TL_DEVICE operator __half() const { return __half(float(*this)); } +}; + +using fp4_e2x2_t = __nv_fp4x2_e2m1; +using fp4_e2x4_t = __nv_fp4x4_e2m1; + +struct fp4_e2x8_t { + fp4_e2_t data[8]; +}; + +struct fp4_e2x16_t { + fp4_e2_t data[16]; +}; + +struct __CUDA_ALIGN__(1) fp4_e2_2_t { + fp4_e2_t x; + fp4_e2_t y; +}; + +struct __CUDA_ALIGN__(2) fp4_e2_4_t { + fp4_e2_t x; + fp4_e2_t y; + fp4_e2_t z; + fp4_e2_t w; +}; + +struct __CUDA_ALIGN__(4) fp4_e2_8_t { + fp4_e2_4_t x; + fp4_e2_4_t y; +}; + +struct __CUDA_ALIGN__(8) fp4_e2_16_t { + fp4_e2_8_t x; + fp4_e2_8_t y; +}; + +struct __CUDA_ALIGN__(16) fp4_e2_32_t { + fp4_e2_16_t x; + fp4_e2_16_t y; + + TL_DEVICE fp4_e2_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp4_e2_8_t *)&rhs.x; + x.y = *(fp4_e2_8_t *)&rhs.y; + y.x = *(fp4_e2_8_t *)&rhs.z; + y.y = *(fp4_e2_8_t *)&rhs.w; + return *this; 
+ } +}; + +struct __CUDA_ALIGN__(32) fp4_e2_64_t { + fp4_e2_32_t x; + fp4_e2_32_t y; +}; + +// Pack two fp4_e2_t values. +TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) { + fp4_e2_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp4_e2_t values. +TL_DEVICE fp4_e2_4_t make_fp4_e2_4_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, + fp4_e2_t x3) { + fp4_e2_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp4_e2_t values. +TL_DEVICE fp4_e2_8_t make_fp4_e2_8_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, + fp4_e2_t x3, fp4_e2_t x4, fp4_e2_t x5, + fp4_e2_t x6, fp4_e2_t x7) { + fp4_e2_8_t result; + result.x = make_fp4_e2_4_t(x0, x1, x2, x3); + result.y = make_fp4_e2_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp4_e2_t values. +TL_DEVICE fp4_e2_16_t make_fp4_e2_16_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, + fp4_e2_t x3, fp4_e2_t x4, fp4_e2_t x5, + fp4_e2_t x6, fp4_e2_t x7, fp4_e2_t y0, + fp4_e2_t y1, fp4_e2_t y2, fp4_e2_t y3, + fp4_e2_t y4, fp4_e2_t y5, fp4_e2_t y6, + fp4_e2_t y7) { + fp4_e2_16_t result; + result.x = make_fp4_e2_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp4_e2_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp4_e2_t values. +TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t( + fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, fp4_e2_t x3, fp4_e2_t x4, + fp4_e2_t x5, fp4_e2_t x6, fp4_e2_t x7, fp4_e2_t x8, fp4_e2_t x9, + fp4_e2_t x10, fp4_e2_t x11, fp4_e2_t x12, fp4_e2_t x13, fp4_e2_t x14, + fp4_e2_t x15, fp4_e2_t y0, fp4_e2_t y1, fp4_e2_t y2, fp4_e2_t y3, + fp4_e2_t y4, fp4_e2_t y5, fp4_e2_t y6, fp4_e2_t y7, fp4_e2_t y8, + fp4_e2_t y9, fp4_e2_t y10, fp4_e2_t y11, fp4_e2_t y12, fp4_e2_t y13, + fp4_e2_t y14, fp4_e2_t y15) { + fp4_e2_32_t result; + result.x = make_fp4_e2_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp4_e2_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} + +#endif diff --git a/tilelang/original/src/tl_templates/cuda/cuda_fp8.h b/tilelang/original/src/tl_templates/cuda/cuda_fp8.h new file mode 100644 index 0000000000000000000000000000000000000000..d1774d47d31d7e7882f0510de04d7f114c35589b --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/cuda_fp8.h @@ -0,0 +1,302 @@ +#pragma once + +#include "common.h" +#include +#include + +using fp8_e4_t = tl::float_e4m3_t; +using fp8_e5_t = tl::float_e5m2_t; +using fp8_e8_t = __nv_fp8_e8m0; + +struct __CUDA_ALIGN__(2) fp8_e4_2_t { + fp8_e4_t x; + fp8_e4_t y; +}; + +struct __CUDA_ALIGN__(4) fp8_e4_4_t { + fp8_e4_t x; + fp8_e4_t y; + fp8_e4_t z; + fp8_e4_t w; +}; + +struct __CUDA_ALIGN__(8) fp8_e4_8_t { + fp8_e4_4_t x; + fp8_e4_4_t y; +}; + +struct __CUDA_ALIGN__(16) fp8_e4_16_t { + fp8_e4_8_t x; + fp8_e4_8_t y; +}; + +struct __CUDA_ALIGN__(32) fp8_e4_32_t { + fp8_e4_16_t x; + fp8_e4_16_t y; + + TL_DEVICE fp8_e4_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e4_8_t *)&rhs.x; + x.y = *(fp8_e4_8_t *)&rhs.y; + y.x = *(fp8_e4_8_t *)&rhs.z; + y.y = *(fp8_e4_8_t *)&rhs.w; + return *this; + } +}; + +struct __CUDA_ALIGN__(2) fp8_e5_2_t { + fp8_e5_t x; + fp8_e5_t y; +}; + +struct __CUDA_ALIGN__(4) fp8_e5_4_t { + fp8_e5_t x; + fp8_e5_t y; + fp8_e5_t z; + fp8_e5_t w; +}; + +struct __CUDA_ALIGN__(8) fp8_e5_8_t { + fp8_e5_4_t x; + fp8_e5_4_t y; +}; + +struct __CUDA_ALIGN__(16) fp8_e5_16_t { + fp8_e5_8_t x; + fp8_e5_8_t y; +}; + +struct __CUDA_ALIGN__(32) fp8_e5_32_t { + fp8_e5_16_t x; + fp8_e5_16_t 
y; + + TL_DEVICE fp8_e5_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e5_8_t *)&rhs.x; + x.y = *(fp8_e5_8_t *)&rhs.y; + y.x = *(fp8_e5_8_t *)&rhs.z; + y.y = *(fp8_e5_8_t *)&rhs.w; + return *this; + } +}; + +struct __CUDA_ALIGN__(2) fp8_e8_2_t { + fp8_e8_t x; + fp8_e8_t y; +}; + +struct __CUDA_ALIGN__(4) fp8_e8_4_t { + fp8_e8_t x; + fp8_e8_t y; + fp8_e8_t z; + fp8_e8_t w; +}; + +struct __CUDA_ALIGN__(8) fp8_e8_8_t { + fp8_e8_4_t x; + fp8_e8_4_t y; +}; + +struct __CUDA_ALIGN__(16) fp8_e8_16_t { + fp8_e8_8_t x; + fp8_e8_8_t y; +}; + +struct __CUDA_ALIGN__(32) fp8_e8_32_t { + fp8_e8_16_t x; + fp8_e8_16_t y; + + TL_DEVICE fp8_e8_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e8_8_t *)&rhs.x; + x.y = *(fp8_e8_8_t *)&rhs.y; + y.x = *(fp8_e8_8_t *)&rhs.z; + y.y = *(fp8_e8_8_t *)&rhs.w; + return *this; + } +}; + +// Pack two fp8_e4_t values. +TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) { + fp8_e4_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp8_e4_t values. +TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, + fp8_e4_t x3) { + fp8_e4_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp8_e4_t values. +TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, + fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5, + fp8_e4_t x6, fp8_e4_t x7) { + fp8_e4_8_t result; + result.x = make_fp8_e4_4_t(x0, x1, x2, x3); + result.y = make_fp8_e4_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp8_e4_t values. +TL_DEVICE fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, + fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5, + fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0, + fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, + fp8_e4_t y7) { + fp8_e4_16_t result; + result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp8_e4_t values. +TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t( + fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4, + fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9, + fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14, + fp8_e4_t x15, fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7, fp8_e4_t y8, + fp8_e4_t y9, fp8_e4_t y10, fp8_e4_t y11, fp8_e4_t y12, fp8_e4_t y13, + fp8_e4_t y14, fp8_e4_t y15) { + fp8_e4_32_t result; + result.x = make_fp8_e4_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp8_e4_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} + +// Pack two fp8_e5_t values. +TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) { + fp8_e5_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp8_e5_t values. +TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, + fp8_e5_t x3) { + fp8_e5_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp8_e5_t values. +TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, + fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5, + fp8_e5_t x6, fp8_e5_t x7) { + fp8_e5_8_t result; + result.x = make_fp8_e5_4_t(x0, x1, x2, x3); + result.y = make_fp8_e5_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp8_e5_t values. 
+TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, + fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5, + fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t y0, + fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3, + fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, + fp8_e5_t y7) { + fp8_e5_16_t result; + result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp8_e5_t values. +TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t( + fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4, + fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9, + fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14, + fp8_e5_t x15, fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3, + fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7, fp8_e5_t y8, + fp8_e5_t y9, fp8_e5_t y10, fp8_e5_t y11, fp8_e5_t y12, fp8_e5_t y13, + fp8_e5_t y14, fp8_e5_t y15) { + fp8_e5_32_t result; + result.x = make_fp8_e5_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp8_e5_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} + +// Pack two fp8_e8_t values. +TL_DEVICE fp8_e8_2_t make_fp8_e8_2_t(fp8_e8_t x, fp8_e8_t y) { + fp8_e8_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp8_e8_t values. +TL_DEVICE fp8_e8_4_t make_fp8_e8_4_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2, + fp8_e8_t x3) { + fp8_e8_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp8_e8_t values. +TL_DEVICE fp8_e8_8_t make_fp8_e8_8_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2, + fp8_e8_t x3, fp8_e8_t x4, fp8_e8_t x5, + fp8_e8_t x6, fp8_e8_t x7) { + fp8_e8_8_t result; + result.x = make_fp8_e8_4_t(x0, x1, x2, x3); + result.y = make_fp8_e8_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp8_e8_t values. +TL_DEVICE fp8_e8_16_t make_fp8_e8_16_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2, + fp8_e8_t x3, fp8_e8_t x4, fp8_e8_t x5, + fp8_e8_t x6, fp8_e8_t x7, fp8_e8_t y0, + fp8_e8_t y1, fp8_e8_t y2, fp8_e8_t y3, + fp8_e8_t y4, fp8_e8_t y5, fp8_e8_t y6, + fp8_e8_t y7) { + fp8_e8_16_t result; + result.x = make_fp8_e8_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp8_e8_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp8_e8_t values. 
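+// (fp8_e8_t wraps __nv_fp8_e8m0, an 8-bit exponent-only format with no sign
+// or mantissa bits; it is typically used for per-block scale factors in
+// MX-style block-scaled formats.)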
+TL_DEVICE fp8_e8_32_t make_fp8_e8_32_t( + fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2, fp8_e8_t x3, fp8_e8_t x4, + fp8_e8_t x5, fp8_e8_t x6, fp8_e8_t x7, fp8_e8_t x8, fp8_e8_t x9, + fp8_e8_t x10, fp8_e8_t x11, fp8_e8_t x12, fp8_e8_t x13, fp8_e8_t x14, + fp8_e8_t x15, fp8_e8_t y0, fp8_e8_t y1, fp8_e8_t y2, fp8_e8_t y3, + fp8_e8_t y4, fp8_e8_t y5, fp8_e8_t y6, fp8_e8_t y7, fp8_e8_t y8, + fp8_e8_t y9, fp8_e8_t y10, fp8_e8_t y11, fp8_e8_t y12, fp8_e8_t y13, + fp8_e8_t y14, fp8_e8_t y15) { + fp8_e8_32_t result; + result.x = make_fp8_e8_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp8_e8_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} + +// e4m3x2 -> float2 +TL_DEVICE float2 +__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x, + const __nv_fp8_interpretation_t fp8_interpretation) { + half2 tmp = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation); + float2 result; + result.x = (float)tmp.x; + result.y = (float)tmp.y; + return result; +} diff --git a/tilelang/original/src/tl_templates/cuda/debug.h b/tilelang/original/src/tl_templates/cuda/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..3f8ce5e6bc1373445a2a6d26b7c77aa669ddebe7 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/debug.h @@ -0,0 +1,128 @@ +#pragma once + +#if __CUDA_ARCH_LIST__ >= 890 +#include "./cuda_fp8.h" +#endif + +#include "common.h" +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +template struct PrintTraits { + static __device__ void print_var(const char *msg, T val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unknown value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (const void *)&val); + } + + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, T val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=unknown value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (const void *)&val); + } +}; + +#define DEFINE_PRINT_TRAIT(TYPE, NAME, FORMAT, CAST_TYPE) \ + template <> struct PrintTraits { \ + static __device__ void print_var(const char *msg, TYPE val) { \ + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \ + "dtype=" NAME " value=" FORMAT "\n", \ + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \ + threadIdx.y, threadIdx.z, (CAST_TYPE)val); \ + } \ + static __device__ void print_buffer(const char *msg, const char *buf_name, \ + int index, TYPE val) { \ + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \ + "buffer=%s, index=%d, dtype=" NAME " value=" FORMAT "\n", \ + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \ + threadIdx.y, threadIdx.z, buf_name, index, (CAST_TYPE)val); \ + } \ + } + +DEFINE_PRINT_TRAIT(char, "char", "%d", int); +DEFINE_PRINT_TRAIT(signed char, "signed char", "%d", int); +DEFINE_PRINT_TRAIT(unsigned char, "unsigned char", "%u", unsigned int); +DEFINE_PRINT_TRAIT(short, "short", "%d", int); +DEFINE_PRINT_TRAIT(unsigned short, "unsigned short", "%u", unsigned int); +DEFINE_PRINT_TRAIT(int, "int", "%d", int); +DEFINE_PRINT_TRAIT(unsigned int, "uint", "%u", unsigned int); +DEFINE_PRINT_TRAIT(long, "long", "%ld", long); +DEFINE_PRINT_TRAIT(unsigned long, "ulong", "%lu", unsigned long); +DEFINE_PRINT_TRAIT(long long, "long long", "%lld", long long); + +DEFINE_PRINT_TRAIT(float, "float", "%f", float); 
+DEFINE_PRINT_TRAIT(double, "double", "%lf", double); +DEFINE_PRINT_TRAIT(half, "half", "%f", float); +DEFINE_PRINT_TRAIT(half_t, "half_t", "%f", float); +DEFINE_PRINT_TRAIT(bfloat16_t, "bfloat16_t", "%f", float); + +#if __CUDA_ARCH_LIST__ >= 890 +DEFINE_PRINT_TRAIT(fp8_e4_t, "fp8_e4_t", "%f", float); +DEFINE_PRINT_TRAIT(fp8_e5_t, "fp8_e5_t", "%f", float); +#endif + +template <> struct PrintTraits { + static __device__ void print_var(const char *msg, bool val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " + "value=%s\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, val ? "true" : "false"); + } + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, bool val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=bool value=%s\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, val ? "true" : "false"); + } +}; + +template struct PrintTraits { + static __device__ void print_var(const char *msg, T *val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=pointer value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (void *)val); + } + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, T *val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=pointer value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (void *)val); + } +}; + +template __device__ void debug_print_var(const char *msg, T var) { + PrintTraits::print_var(msg, var); +} + +template +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var) { + PrintTraits::print_buffer(msg, buf_name, index, var); +} + +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, uint16_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=uint16_t value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (uint32_t)var); +} + +TL_DEVICE void device_assert(bool cond) { assert(cond); } + +TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { + if (!cond) { + printf("Device assert failed: %s\n", msg); + assert(0); + } +} \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/cuda/gemm.h b/tilelang/original/src/tl_templates/cuda/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..b0b2a1b42e0275f474345b8b5bbf6d9838a59a96 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm.h @@ -0,0 +1,18 @@ +#pragma once + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)) +#include "gemm_sm120.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000)) +#include "gemm_sm100.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "./instruction/wgmma.h" +#include "gemm_sm90.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)) +#include "gemm_sm89.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) +#include "gemm_sm80.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700)) +#include "gemm_sm70.h" +#else +// No matching architecture found +#endif diff --git 
a/tilelang/original/src/tl_templates/cuda/gemm_mma.h b/tilelang/original/src/tl_templates/cuda/gemm_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..25841a3b6d40801ddbdefbc9f5a9178ec2f9d648 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_mma.h @@ -0,0 +1,483 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "intrin.h" + +namespace cute::tl_mma { + +template +struct DispatchInstruction; + +using _X = Underscore; + +} // namespace cute::tl_mma + +#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \ + namespace cute::tl_mma { \ + template \ + struct DispatchInstruction { \ + using MMA = MMA_Atom; \ + using MMA_Group = Tile<_X, Int, _X>; \ + }; \ + } +#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \ + namespace cute::tl_mma { \ + template \ + struct DispatchInstruction { \ + using MMA = MMA_Atom>; \ + using MMA_Group = Tile<_X, Int, _X>; \ + }; \ + } + +#ifdef __CUDA_ARCH_LIST__ +#if __CUDA_ARCH_LIST__ >= 1200 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 1000 +#include "cuda_fp8.h" +#include +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 900 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 890 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) 
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 800 +#include +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 750 +TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN) +#endif +#endif +#undef TL_DISPATCH_MMA +#undef TL_DISPATCH_MMA_TEMPLATE + +namespace cute::tl_mma { + +template struct SelectCopy { + static constexpr int remainder = (N / num_warp_n) % 16; + using type = std::conditional_t< + remainder == 4 || remainder == 8 || remainder == 0, + std::conditional_t< + transpose, + std::conditional_t< + remainder == 4, SM75_U32x1_LDSM_N, + std::conditional_t>, + std::conditional_t< + remainder == 4, SM75_U16x2_LDSM_T, + std::conditional_t>>, + DefaultCopy>; +}; + +template +struct OperandTraits { + // Primary template, use padded layout and default copy + static constexpr int stride = leading_dim; + static constexpr int padded = + stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride; + using Layout = typename std::conditional< + K_inner, Layout, Int>, Shape, _1>>, + Layout, Int>, Shape<_1, Int>>>::type; + using Copy = DefaultCopy; +}; + +template +struct OperandTraits<16, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout, Stride<_32, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 3, 3>{}, Layout, Stride<_64, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 2, 3>{}, Layout, Stride<_32, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 3>{}, Layout, Stride<_16, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = 
typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 2, 3>{}, Layout, Stride<_1, _32>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = UniversalCopy; +}; + +template +struct OperandTraits<32, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 3>{}, Layout, Stride<_1, _16>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = UniversalCopy; +}; + +template +struct OperandTraits<8, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 4, 3>{}, Layout, Stride<_64, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<8, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 4, 3>{}, Layout, Stride<_128, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<64, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 0, 4>{}, Layout, Stride<_16, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = DefaultCopy; +}; + +template +struct OperandTraits<64, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 2>{}, Layout, Stride<_1, _16>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = DefaultCopy; +}; + +template +class GemmTensorOp { +public: + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; + using A_type = + typename std::conditional::value, + tfloat32_t, A_type_cute>::type; + using B_type = + typename std::conditional::value, + tfloat32_t, B_type_cute>::type; + using C_type = C_type_raw; + + using Instruction = DispatchInstruction; + + using OperandATraits = OperandTraits::value, M, K, + !trans_A, num_warp_m, lda>; + using OperandBTraits = + OperandTraits::value, N, K, trans_B, num_warp_n, ldb>; + + using SmemLayoutA = typename OperandATraits::Layout; + using SmemLayoutB = typename OperandBTraits::Layout; + using SmemCopyA = Copy_Atom; + using SmemCopyB = Copy_Atom; + + using TileMma = TiledMMA, Int, _1>>, + typename Instruction::MMA_Group>; + + template + static CUTE_DEVICE auto remove_swizzle(Layout const &layout) { + return layout; + } + // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0 + // the original layout fail to compile, currently using this as a workaround + template + static CUTE_DEVICE auto + remove_swizzle(ComposedLayout const &layout) { + if constexpr (sizeof(A_type) == 2) + return layout.layout_b(); + else + return layout; + } + + template + static CUTE_DEVICE auto get_region_tensor(Tensor &sa) { + if constexpr (offset == 0) { + return composition( + sa, + Layout, Int>, + Stride<_1, typename std::conditional, + Int>::type>>{}); + } else { + if constexpr (trans) { + static_assert(offset % KK == 0, "Offset must be a multiple of K"); + constexpr int offset_n = 
offset / KK; + return flat_divide(sa, Shape, Int>{})(_, _, _0{}, + Int{}); + } else { + static_assert(offset % NN == 0, "Offset must be a multiple of N"); + constexpr int offset_n = offset / NN; + return flat_divide(sa, Shape, Int>{})(_, _, Int{}, + _0{}); + } + } + } + + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + Tensor sA = get_region_tensor(sA_all); + Tensor sB = get_region_tensor(sB_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma); + auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(tid); + auto thr_copy_B = tiled_copy_B.get_thread_slice(tid); + + Tensor tCrA = thr_mma.partition_fragment_A(sA); + Tensor tCrB = thr_mma.partition_fragment_B(sB); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCsB = thr_copy_B.partition_S(sB); + + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a + // workaround + auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout())); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); + if constexpr (clear_accum) { + clear(acc); + } + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k)); + copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k)); + gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc); + } + } + + static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + Tensor sB = get_region_tensor(sB_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(tid); + + Tensor tCrB = thr_mma.partition_fragment_B(sB); + Tensor tCsB = thr_copy_B.partition_S(sB); + + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + Tensor tCrA = + make_tensor(make_rmem_ptr(reinterpret_cast(pA)), + partition_shape_A(tiled_mma, Shape, Int>{})); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); + if constexpr (clear_accum) { + clear(acc); + } + copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc); + } + } + + static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sA = get_region_tensor(sA_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma); + auto thr_copy_A = 
tiled_copy_A.get_thread_slice(tid); + + Tensor tCrA = thr_mma.partition_fragment_A(sA); + Tensor tCsA = thr_copy_A.partition_S(sA); + + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + Tensor tCrB = + make_tensor(make_rmem_ptr(reinterpret_cast(pB)), + partition_shape_B(tiled_mma, Shape, Int>{})); + auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout())); + if constexpr (clear_accum) { + clear(acc); + } + copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc); + } + } +}; + +} // namespace cute::tl_mma + +namespace tl::tl_mma { + +template +CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body(pA, pB, accum); +} + +template +CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); +} + +template +CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_sr(pA, pB, accum); +} + +} // namespace tl::tl_mma diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm100.h b/tilelang/original/src/tl_templates/cuda/gemm_sm100.h new file mode 100644 index 0000000000000000000000000000000000000000..84e22f24e1f11804c192233f5971f87f9998f4db --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm100.h @@ -0,0 +1,438 @@ +// Licensed under the MIT License. +#pragma once + +#include "common.h" +#include "gemm_mma.h" +#include "intrin.h" + +#include +#include +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support TCGEN5MMA with .ws, so we add it here +// About why we need .ws, plz refer to comments in tl_tcgen5mma::GemmTensorOp + +template +struct SM100_MMA_F16BF16_WS_SS { + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F16BF16 (with .ws) M-mode size should be 32, 64 or " + "128 for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16 (with .ws) N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), + "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && + cute::sizeof_bits_v == 16, + "SM100_MMA_F16BF16_WS_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = 
Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), + idesc); + } +}; + +struct SM100_MMA_F8F6F4_WS_SS { + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, " + "p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), + "r"(uint32_t(idescE >> 32)), "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits, + cute::C, cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && + cute::sizeof_bits_v <= 8, + "SM100_MMA_F8F6F4_WS_SS supports types with leq 8bit types"); + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F8F6F4_WS_SS M-mode size should be 32, 64 or 128 " + "for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F8F6F4_WS_SS (with .ws) N-mode size should be 32, 64 or 128"); + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +namespace tl_tcgen5mma { + +using cutlass::gemm::collective::detail::sm100_smem_selector; + +template +struct DispatchInstruction; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +class GemmTensorOp { +public: + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; + using A_type = + typename std::conditional::value, + tfloat32_t, A_type_cute>::type; + using B_type = + typename std::conditional::value, + tfloat32_t, B_type_cute>::type; + using C_type = C_type_raw; + + static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32); + + static constexpr UMMA::Major UmmaMajorA = + trans_A ? UMMA::Major::MN : UMMA::Major::K; + static constexpr UMMA::Major UmmaMajorB = + trans_B ? 
UMMA::Major::K : UMMA::Major::MN; + + using SmemLayoutAtomA = + decltype(sm100_smem_selector, Int>()); + using SmemLayoutAtomB = + decltype(sm100_smem_selector, Int>()); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + + static CUTE_DEVICE void body_ss(A_type_raw *pA, B_type_raw *pB, uint32_t pC, + uint64_t *umma_bar_ptr, bool clear_accum) { + Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + + // TODO (lei): Normal TCGEN5MMA (the one w/o ws) don't saturate all 128 + // lanes when M == 64 + // (see layout F in + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-f) + // So we use the .ws variant here + using MmaAtom = + typename DispatchInstruction::MMA; + auto tiled_mma = make_tiled_mma(MmaAtom{}, Layout>{}, + Tile, Int, Int>{}); + auto thr_mma = tiled_mma.get_slice(_0{}); + tiled_mma.accumulate_ = + clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + Tensor acc = partition_fragment_C(tiled_mma, Shape, Int>{}); + acc.data() = pC; + + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(sA_frag); ++k_block) { + cute::gemm(tiled_mma, sA_frag(_, _, k_block), sB_frag(_, _, k_block), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + cutlass::arch::umma_arrive(umma_bar_ptr); + } +}; + +} // namespace tl_tcgen5mma + +} // namespace cute + +namespace tl { + +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; + +// TODO (lei): Implement gemm_ts +// template +// TL_DEVICE void gemm_ts(A_type *pA, B_type *pB, C_type *accum, uint64_t +// *umma_bar_ptr) { +// } + +template +TL_DEVICE void tcgen5mma_gemm_ss(A_type *pA, B_type *pB, uint32_t accum, + Barrier_type *umma_bar_ptr, bool clear_accum) { + using MMA = + cute::tl_tcgen5mma::GemmTensorOp; + MMA::body_ss(pA, pB, accum, reinterpret_cast(umma_bar_ptr), + clear_accum); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm120.h b/tilelang/original/src/tl_templates/cuda/gemm_sm120.h new file mode 100644 index 0000000000000000000000000000000000000000..122f56642aff68b89fae817777f57cf3f0c31fe5 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm120.h @@ -0,0 +1,9 @@ +#pragma once + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm70.h b/tilelang/original/src/tl_templates/cuda/gemm_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..75127727935ea1377c8f9bcfaaa64bac135831fe --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm70.h @@ -0,0 +1,188 @@ +#pragma once + +#include +#include + +#include "common.h" + +using cutlass::gemm::GemmShape; + +// Primary template +// Add 128 bits padding when the last dim is a multiple of 256 bits +template +struct DispatchSharedMemoryLayoutA { + using Layout = + typename std::conditional::type; + static int constexpr Dim = transpose ? M : K; + static int constexpr Stride = + (Dim * sizeof(T) % 32 == 0) ? 
Dim + 16 / sizeof(T) : Dim; +}; +template +struct DispatchSharedMemoryLayoutB { + using Layout = + typename std::conditional::type; + static int constexpr Dim = transpose ? K : N; + static int constexpr Stride = + (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim; +}; + +// Partial specialization for half_t +template +struct DispatchSharedMemoryLayoutA::type> { + using Layout = + cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>; + static int constexpr Stride = M; +}; + +template +struct DispatchSharedMemoryLayoutA { + using Layout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>; + static int constexpr Stride = M; +}; + +template struct DispatchSharedMemoryLayoutB { + using Layout = + cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>; + static int constexpr Stride = N; +}; + +template +struct DispatchSharedMemoryLayoutB::type> { + using Layout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>; + static int constexpr Stride = N; +}; + +template +class GemmTensorOp { +public: + using A_type = A_type_raw; + using B_type = B_type_raw; + using C_type = C_type_raw; + using InstructionShape = GemmShape<16, 16, 4>; + using SMemLayoutA = + typename DispatchSharedMemoryLayoutA::Layout; + using SMemLayoutB = + typename DispatchSharedMemoryLayoutB::Layout; + static constexpr int stride_A = + DispatchSharedMemoryLayoutA::Stride; + static constexpr int stride_B = + DispatchSharedMemoryLayoutB::Stride; + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape, 32, A_type, + typename std::conditional::type, + B_type, + typename std::conditional::type, + C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; + + static_assert(Shape::kM % num_warp_m == 0); + static_assert(Shape::kN % num_warp_n == 0); + + using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp< + GemmShape, + A_type, SMemLayoutA, B_type, SMemLayoutB, C_type, + cutlass::layout::RowMajor, Policy>; + + using TensorRefA = typename MmaWarp::IteratorA::TensorRef; + using TensorRefB = typename MmaWarp::IteratorB::TensorRef; + using FragmentA = typename MmaWarp::FragmentA; + using FragmentB = typename MmaWarp::FragmentB; + using FragmentC = typename MmaWarp::FragmentC; + using IteratorA = typename MmaWarp::IteratorA; + using IteratorB = typename MmaWarp::IteratorB; + + static_assert(Shape::kK % InstructionShape::kK == 0); + static int constexpr kKgroups = Shape::kK / InstructionShape::kK; + + static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB, + FragmentC &accum, const int warp_idx_m, + const int warp_idx_n, const int lane_id) { + MmaWarp mma_op; + FragmentA frag_A; + FragmentB frag_B; + const TensorRefA ref_A((A_type *)pA, stride_A); + const TensorRefB ref_B((B_type *)pB, stride_B); + IteratorA iter_A(ref_A, lane_id); + IteratorB iter_B(ref_B, lane_id); + iter_A.add_tile_offset({warp_idx_m, 0}); + iter_B.add_tile_offset({0, warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_A.load(frag_A); + iter_B.load(frag_B); + ++iter_A; + ++iter_B; + mma_op(accum, frag_A, frag_B, accum); + } + } + + static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB, + FragmentC &accum, const int warp_idx_n, + const int lane_id) { + MmaWarp mma_op; + FragmentB frag_B; + const TensorRefB ref_B((B_type *)pB, stride_B); + IteratorB iter_B(ref_B, lane_id); + iter_B.add_tile_offset({0, 
warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_B.load(frag_B); + ++iter_B; + mma_op(accum, frag_A[k], frag_B, accum); + } + } +}; + +namespace tl { + +template +CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using MMA = GemmTensorOp, num_warp_m, num_warp_n, trans_A, + trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentC = typename MMA::FragmentC; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n, + warp_id % num_warp_n, lane_id); +} + +template +CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using MMA = GemmTensorOp, num_warp_m, num_warp_n, trans_A, + trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentA = typename MMA::FragmentA; + using FragmentC = typename MMA::FragmentC; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum), + warp_id % num_warp_n, lane_id); +} + +}; // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm80.h b/tilelang/original/src/tl_templates/cuda/gemm_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..122f56642aff68b89fae817777f57cf3f0c31fe5 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm80.h @@ -0,0 +1,9 @@ +#pragma once + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm89.h b/tilelang/original/src/tl_templates/cuda/gemm_sm89.h new file mode 100644 index 0000000000000000000000000000000000000000..d64ae9e2e6986fadcd2ea62b10627ec9cf808710 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm89.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#include "cuda_fp8.h" + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sm90.h b/tilelang/original/src/tl_templates/cuda/gemm_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..543a29d096b7b3e1c14c5ca8532f5d344b7269db --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sm90.h @@ -0,0 +1,387 @@ +#pragma once + +#include "common.h" +#include "gemm_mma.h" +#include "intrin.h" + +#include +#include +#include + +namespace cute { + +using namespace SM90; + +namespace tl_wgmma { + +using namespace cutlass::gemm::collective::detail; // ss_smem_selector + +template +class GemmTensorOp { +public: + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; + using A_type = conditional_t::value, + tfloat32_t, A_type_cute>; + using B_type = conditional_t::value, + tfloat32_t, A_type_cute>; + using C_type = C_type_raw; + + static constexpr GMMA::Major GmmaMajorA = + trans_A ? GMMA::Major::MN : GMMA::Major::K; + static constexpr GMMA::Major GmmaMajorB = + trans_B ? 
GMMA::Major::K : GMMA::Major::MN; + + using SmemLayoutAtomA = + decltype(ss_smem_selector, Int>()); + using SmemLayoutAtomB = + decltype(ss_smem_selector, Int>()); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + + static_assert(num_warp_m % 4 == 0, + "num_warp_m must be a multiple of 4 for hopper wgmma"); + + template + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + auto tiled_mma = make_tiled_mma( + GMMA::ss_op_selector< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{}); + auto thr_mma = tiled_mma.get_thread_slice(tid); + + // Allocate registers for pipelining + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + } + + template + static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + // TODO: Move bar.sync out of body_rs + // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * + // 32)); + const int tid = threadIdx.x; + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + auto tiled_mma = make_tiled_mma( + GMMA::rs_op_selector< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{}); + auto thr_mma = tiled_mma.get_thread_slice(tid); + + // Allocate registers for pipelining + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCrA = + make_tensor(make_rmem_ptr(reinterpret_cast(pA)), + partition_shape_A(tiled_mma, Shape, Int>{})); + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(tCrA); + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + warpgroup_fence_operand(tCrA); + } +}; + +} // namespace tl_wgmma + +} // namespace cute +/** + 
* Execute a tiled GEMM where A is read from global memory and B is staged in + * shared memory. + * + * Dispatches to tl_mma::GemmTensorOp::body_rs to perform the + * computation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Execute a tiled GEMM where A is staged in shared memory and B is read from + * global memory. + * + * Dispatches to tl_mma::GemmTensorOp::body_sr to perform the + * computation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM (both operands in shared memory or selected backend) and + * write to accum. + * + * If use_wgmma is true, validates wgmma constraints (strides and offsets) and + * dispatches to the Hopper wgmma implementation; otherwise dispatches to the + * tl_mma implementation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM with A in global memory and B in shared memory (or + * selected backend). + * + * If use_wgmma is true, validates wgmma constraints (strides and offsets) and + * dispatches to the Hopper wgmma read-share implementation; otherwise + * dispatches to the tl_mma read-share. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM with A staged in shared memory and B in global memory + * (tl_mma only). + * + * wgmma does not support this variant; caller must set use_wgmma == false. + * Dispatches to tl_mma::GemmTensorOp::body_sr. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Wait for a warp-group of WMMA/MMA warps to complete. + * + * Wrapper around cute::warpgroup_wait for the specified number of MMA warps. + */ +/** + * Synchronize a named barrier across NumMmaThreads MMA threads. + * + * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id. + */ +/** + * Arrive at a named barrier for NumMmaThreads MMA threads using + * architecture-aware mapping. + * + * Supported NumMmaThreads values: 256 or 384. The function issues one or two + * barrier arrives depending on the thread-group topology to ensure proper + * rendezvous ordering. + */ +/** + * Initialize named-barrier state for multi-warp MMA execution. + * + * For NumMmaThreads == 256 or 384, performs the required initial barrier + * arrivals for non-zero canonical warp-group indices to set up subsequent + * barrier synchronization. 
+ */ + +namespace tl { + +template +TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + if constexpr (use_wgmma) { + static_assert((trans_A && lda == M) || (!trans_A && lda == K), + "Hopper wgmma doesn't support custom stride for A"); + static_assert((trans_B && ldb == K) || (!trans_B && ldb == N), + "Hopper wgmma doesn't support custom stride for B"); + static_assert(offset_a == 0 && offset_b == 0, + "offset_a and offset_b must be zero for wgmma"); + using MMA = cute::tl_wgmma::GemmTensorOp; + MMA::body(pA, pB, accum); + } else { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body(pA, pB, accum); + } +} + +template +TL_DEVICE /** + * Perform a read-share (B in shared memory, A in global) tiled GEMM + * and accumulate into `accum`. + * + * Dispatches at compile time to either the Hopper wgmma + * implementation or the fallback MMA implementation depending on + * `use_wgmma`. The selected GemmTensorOp::body_rs performs the + * region-tiled GEMM loop and updates the accumulator in-place. + * + * When `use_wgmma == true`, this function enforces wgmma constraints + * at compile time: + * - A's leading dimension must equal (trans_A ? M : K) + * - B's leading dimension must equal (trans_B ? K : N) + * - offset_a and offset_b must be zero + * + * @param pA Pointer to operand A (global memory). Layout/stride + * expectations depend on template parameters. + * @param pB Pointer to operand B (base for shared-memory staging). + * Layout/stride expectations depend on template parameters. + * @param accum Pointer to the accumulator/output C buffer updated + * in-place. + */ + void + gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + if constexpr (use_wgmma) { + static_assert((trans_A && lda == M) || (!trans_A && lda == K), + "Hopper wgmma doesn't support custom stride for A"); + static_assert((trans_B && ldb == K) || (!trans_B && ldb == N), + "Hopper wgmma doesn't support custom stride for B"); + static_assert(offset_a == 0 && offset_b == 0, + "offset_a and offset_b must be zero for wgmma"); + using MMA = cute::tl_wgmma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); + } else { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); + } +} + +template +TL_DEVICE /** + * Perform a non-wgmma tiled GEMM where A regions are staged into + * shared memory and B is read directly from global memory, + * accumulating into `accum`. + * + * This overload dispatches to the tl_mma::GemmTensorOp::body_sr + * implementation. Must be instantiated with `use_wgmma = false` + * (enforced via static_assert). + * + * @param pA Pointer to the A operand in global memory (source that + * will be staged to shared memory). + * @param pB Pointer to the B operand in global memory (read + * directly). + * @param accum Pointer to the output accumulator matrix in global + * memory. + */ + void + gemm_sr(A_type *pA, B_type *pB, C_type *accum) { + static_assert(!use_wgmma, "wgmma doesn't support gemm_sr"); + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_sr(pA, pB, accum); +} + +template +TL_DEVICE /** + * Wait for all WMMA/MMA warps in the current warp-group to + * synchronize. + * + * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes + * completes, ensuring all participating warps have arrived before + * proceeding. 
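+ *
+ * Hedged usage sketch (illustrative only; the template argument shown is a
+ * hypothetical value, not taken from this file): when a kernel issues one of
+ * the tl::gemm_* wgmma wrappers with a negative wg_wait, no warpgroup_wait is
+ * executed inside that call, so the accumulator fragment must not be read
+ * until the caller drains the outstanding batch, e.g.
+ *
+ *   // ... issue a wgmma-backed GEMM with wg_wait < 0 ...
+ *   tl::wait_wgmma<0>();  // wait until no wgmma batches remain outstanding
+ *   // the accumulator fragment is now safe to read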
+ */ + void + wait_wgmma() { + cute::warpgroup_wait(); +} + +template TL_DEVICE void warp_scheduler_barrier_sync() { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, + cutlass::canonical_warp_group_idx() /*id*/); +} + +template TL_DEVICE void warp_scheduler_barrier_arrive() { + static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); + if constexpr (NumMmaThreads == 256) { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + (cutlass::canonical_warp_group_idx() <= 1 + ? cutlass::canonical_warp_group_idx() + 1 + : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + (cutlass::canonical_warp_group_idx() <= 0 + ? cutlass::canonical_warp_group_idx() + 2 + : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } +} + +template TL_DEVICE void mma_init() { + static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); + if (cutlass::canonical_warp_group_idx() > 0) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0); + } + if constexpr (NumMmaThreads == 384) { + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/); + } + } +} +} // namespace tl
diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sp.h b/tilelang/original/src/tl_templates/cuda/gemm_sp.h new file mode 100644 index 0000000000000000000000000000000000000000..f40a7bd0f8e2df261cb9aff94e2e0d9072438255 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sp.h @@ -0,0 +1,6 @@ +#pragma once +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "gemm_sp_sm90.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800)) +#include "gemm_sp_sm80.h" +#endif
diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sp_sm80.h b/tilelang/original/src/tl_templates/cuda/gemm_sp_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..f1fc860092e2816b21c16d97c6f2f3ce01344338 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sp_sm80.h @@ -0,0 +1,270 @@ +#include +#include + +namespace tl { + +static int const kSparse = 2; +template struct ShapeCheck { + static constexpr bool value = false; +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 32 == 0) && (Shape::kN % 32 == 0) && (Shape::kK % 32 == 0); +}; + +template struct ShapeCheck { + static constexpr bool value = + ShapeCheck::value; // Same as half +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); +}; + +// ref: +// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h +template struct DispatchInstructionShape { + static_assert(!std::is_same_v, + "Unsupported type for DispatchInstructionShape"); +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 32>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 32>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// TODO: Not supported for now +// template<> +// struct DispatchInstructionShape { +// using Shape = cutlass::gemm::GemmShape<16, 8, 16>; +// using
Operator = cutlass::arch::OpMultiplyAdd; +// }; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 64>; + using Operator = cutlass::arch::OpMultiplyAddSaturate; +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 64>; + using Operator = cutlass::arch::OpMultiplyAddSaturate; +}; + +// TODO: Not supported for now +// template<> +// struct DispatchInstructionShape { +// using Shape = cutlass::gemm::GemmShape<16, 8, 128>; +// using Operator = cutlass::arch::OpMultiplyAddSaturate; +// }; + +template +struct DispatchSharedMemoryLayoutA; + +template +struct DispatchSharedMemoryLayoutA { + using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, K / kSparse>; +}; + +template +struct DispatchSharedMemoryLayoutA { + static int const Crosswise_A = + cutlass::platform::min(int(128 / sizeof(T)), M); + using SmemLayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, Crosswise_A>; +}; + +template +struct DispatchSharedMemoryLayoutB; + +template +struct DispatchSharedMemoryLayoutB { + static_assert( + cutlass::sizeof_bits::value != 8, + "int8, uint8, float8 only support column major layout for matrix B"); + static int const Crosswise_B = + cutlass::platform::min(int(128 / sizeof(T)), N); + using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, Crosswise_B>; +}; + +template +struct DispatchSharedMemoryLayoutB { + static int const kCrosswiseB = (K > (1024 / cutlass::sizeof_bits::value)) + ? (1024 / cutlass::sizeof_bits::value) + : K; + using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, kCrosswiseB>; +}; + +template struct DispatchType { + static_assert(std::is_same::value, "Unsupported dtype"); +}; + +template <> struct DispatchType { + using Type = cutlass::half_t; +}; + +template <> struct DispatchType { + using Type = cutlass::bfloat16_t; +}; + +template <> struct DispatchType { + using Type = uint8_t; +}; + +template <> struct DispatchType { + using Type = int8_t; +}; + +template +class GemmTensorOp { +public: + static_assert(Shape::kM % num_warp_m == 0); + static_assert(Shape::kN % num_warp_n == 0); + using ElementA = typename DispatchType::Type; + using ElementB = typename DispatchType::Type; + using ElementC = C_type_raw; + + static_assert(std::is_same_v, + "A and B are not the same type"); + static_assert(ShapeCheck::value, + "Invalid shape for ElementA"); + + using LayoutA = + typename std::conditional_t; + using LayoutB = + typename std::conditional_t; + using LayoutC = cutlass::layout::RowMajor; + using ThreadblockShape = Shape; + using SmemLayoutA = + typename DispatchSharedMemoryLayoutA::SmemLayoutA; + using SmemLayoutB = + typename DispatchSharedMemoryLayoutB::SmemLayoutB; + + using WarpShape = cutlass::gemm::GemmShape; + using InstructionShape = typename DispatchInstructionShape::Shape; + using Operator = typename DispatchInstructionShape::Operator; + static_assert(WarpShape::kK % InstructionShape::kK == 0, + "K dimension must be divisible by instruction shape K."); + + // instruction/warp config + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::SparseMma, + cutlass::MatrixShape<1, 1>>; + using MmaWarp = + cutlass::gemm::warp::SparseMmaTensorOp; + static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse"); + + using SmemLayoutE = typename MmaWarp::LayoutE; + 
static_assert(std::is_same_v, + "Meta data layout must be ColumnMajor for sparse mma."); + + // other traits + using FragmentA = typename MmaWarp::FragmentA; + using FragmentB = typename MmaWarp::FragmentB; + using FragmentC = typename MmaWarp::FragmentC; + using FragmentE = typename MmaWarp::FragmentE; + + using IteratorA = typename MmaWarp::IteratorA; + using IteratorB = typename MmaWarp::IteratorB; + using IteratorE = typename MmaWarp::IteratorE; + + using TensorRefA = typename IteratorA::TensorRef; + using TensorRefB = typename IteratorB::TensorRef; + using TensorRefE = typename IteratorE::TensorRef; + using ElementE = typename TensorRefE::Element; + + static int const kElementsPerElementE = MmaWarp::kElementsPerElementE; + static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse"); + + using ShapeA = cutlass::MatrixShape; + using ShapeB = cutlass::MatrixShape; + using ShapeE = + cutlass::MatrixShape; + + static int constexpr kKgroups = WarpShape::kK / InstructionShape::kK; + + template + static CUTLASS_DEVICE void + body(A_type_raw *pA, E_type_raw *pE, B_type_raw *pB, FragmentC &accum, + const int warp_idx_m, const int warp_idx_n, const int lane_id) { + MmaWarp mma_op; + FragmentA frag_a; + FragmentB frag_b; + FragmentE frag_e; + const TensorRefA ref_A( + (ElementA *)pA, + MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn})); + const TensorRefE ref_E( + (ElementE *)pE, + MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn})); + const TensorRefB ref_B( + (ElementB *)pB, + MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn})); + IteratorA iter_A(ref_A, lane_id); + IteratorE iter_E(ref_E, lane_id); + IteratorB iter_B(ref_B, lane_id); + iter_A.add_tile_offset({warp_idx_m, 0}); + iter_E.add_tile_offset({warp_idx_m, 0}); + iter_B.add_tile_offset({0, warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_A.load(frag_a); + iter_E.load(frag_e); + iter_B.load(frag_b); + ++iter_A; + ++iter_E; + ++iter_B; + mma_op(accum, frag_a, frag_b, accum, frag_e); + } + } +}; + +template +TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { + using MMA = + GemmTensorOp, num_warp_m, num_warp_n, + trans_A, trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentC = typename MMA::FragmentC; + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m, + warp_id / num_warp_m, lane_id); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/gemm_sp_sm90.h b/tilelang/original/src/tl_templates/cuda/gemm_sp_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..6184f9be7a02ba0feaee897c0eea887db2e51a3a --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/gemm_sp_sm90.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include + +namespace cute { +namespace tl_wgmma_sp { +template +class GemmTensorOp { +public: + static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); + + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; + using A_type = conditional_t::value, + tfloat32_t, A_type_cute>; + using B_type = conditional_t::value, + tfloat32_t, B_type_cute>; + using C_type = C_type_raw; + + static constexpr bool need_tfloat32_cast = + std::is_same::value && + std::is_same::value; + + static constexpr GMMA::Major GmmaMajorA = + trans_A ? 
GMMA::Major::MN : GMMA::Major::K; + static constexpr GMMA::Major GmmaMajorB = + trans_B ? GMMA::Major::K : GMMA::Major::MN; + + using TiledMma = decltype(make_tiled_mma( + GMMA::ss_op_selector_sparse< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementBMma = typename TiledMma::ValTypeB; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementEMmaSparsity = Int; + using E_type_raw = typename ElementEMma::raw_type; + + using SparseConfig = + cutlass::Sm90GemmSparseConfig{}, _128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + + using SmemLayoutAtomA = + decltype(cutlass::gemm::collective::detail::ss_smem_selector_sparse< + GmmaMajorA, A_type, Int, Int, ElementAMmaSparsity>()); + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, B_type, Int, Int>()); + + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = + ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + + using SmemCopyAtomE = AutoVectorizingCopy; + + template + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC, + E_type_raw *pE) { + const int tid = threadIdx.x; + Tensor sA = + make_tensor(make_smem_ptr(recast_ptr(pA)), SmemLayoutA{}); + Tensor sB = + make_tensor(make_smem_ptr(recast_ptr(pB)), SmemLayoutB{}); + Tensor sE = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(recast_ptr(pE)), SmemLayoutE{})); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + + Tensor tCsA = thr_mma.partition_A(sA); + Tensor tCsB = thr_mma.partition_B(sB); + Tensor tCsE = partition_E(thr_mma, sE(_, _)); + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); + Tensor tCrB = thr_mma.make_fragment_B(tCsB); + Tensor tCrE = make_fragment_like(tCsE); + + auto copy_atom_E = Copy_Atom{}; + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(tid); + Tensor tEsE = smem_thr_copy_E.partition_S(sE); + Tensor tErE = smem_thr_copy_E.retile_D(tCrE); + + Tensor acc = + make_tensor(make_rmem_ptr(pC), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + copy(smem_tiled_copy_E, tEsE, tErE); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, make_zip_tensor(tCrA(_, _, k_block), tCrE(_, _, k_block)), + tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + } + + template + CUTE_HOST_DEVICE static constexpr auto + thrfrg_E(TiledMMA const &mma, + ETensor &&etensor) { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor 
for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = + make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = + zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = + e_tensor.compose(AtomLayoutE_TV{}, _); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = + make_tile(_, make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide( + tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr auto + get_layoutE_TV(TiledMMA const &mma) { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile( + _, make_tile(make_layout(make_shape(size<1>(mma.thr_layout_vmnk_), + size<2>(mma.thr_layout_vmnk_)), + make_stride(Int<1>{}, Int<0>{})), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr auto + partition_E(ThrMMA const &thr_mma, ETensor &&etensor) { + auto thr_tensor = make_tensor(static_cast(etensor).data(), + thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord( + get<0>(thr_mma.thr_vmnk_), + make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, + make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr auto + make_tiled_copy_E(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return make_tiled_copy_impl( + copy_atom, get_layoutE_TV(mma), + make_shape(tile_size<0>(mma), tile_size<2>(mma))); + } +}; + +} // namespace tl_wgmma_sp +} // namespace cute + +namespace tl { +template , + typename E_type = typename GMMA::ElementEMma::raw_type> +TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { + static_assert(use_wgmma, "only wgmma is supported for now"); + if constexpr (use_wgmma) { + GMMA::body(pA, pB, accum, pE); + } else { + CUTE_GCC_UNREACHABLE; + } +} +} // namespace tl \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/cuda/instruction/mma.h b/tilelang/original/src/tl_templates/cuda/instruction/mma.h new file mode 100644 index 0000000000000000000000000000000000000000..869fa777bc1cdb72a62a5645c135c5393e308d96 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/instruction/mma.h @@ -0,0 +1,165 @@ +#pragma once + +#include "../common.h" +#include +#include + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +namespace detail { + +template struct MmaImplTraits { + using DReg = std::remove_extent_t; + using AReg = std::remove_extent_t; + using BReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + + 
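+  // Fragment register counts, taken from the array extents of the wrapped
+  // CUTE MMA op's register typedefs (for example, an op whose DRegisters is
+  // declared as a 4-element array yields kDRegs == 4). These counts drive the
+  // index_sequence expansion performed by call_fma below.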
static constexpr int kDRegs = std::extent_v; + static constexpr int kARegs = std::extent_v; + static constexpr int kBRegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; +}; + +template +TL_DEVICE void +call_fma_impl(typename MmaImplTraits::DReg *d, + const typename MmaImplTraits::AReg *a, + const typename MmaImplTraits::BReg *b, + const typename MmaImplTraits::CReg *c, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence) { + Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...); +} + +template +TL_DEVICE void call_fma(typename MmaImplTraits::DReg *d, + const typename MmaImplTraits::AReg *a, + const typename MmaImplTraits::BReg *b, + const typename MmaImplTraits::CReg *c) { + call_fma_impl(d, a, b, c, + std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}); +} + +template +struct MmaDispatcher { + using CRegType = void; + using ARegType = void; + using BRegType = void; + + static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, + const CRegType *) { + static_assert(always_false_v>, + "tl::mma_sync: unsupported configuration"); + } +}; + +#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \ + NValue, KValue, TransAValue, TransBValue, \ + SaturateValue, ImplType) \ + template <> \ + struct MmaDispatcher { \ + using Impl = ImplType; \ + using Traits = MmaImplTraits; \ + using CRegType = typename Traits::DReg; \ + using ARegType = typename Traits::AReg; \ + using BRegType = typename Traits::BReg; \ + static_assert( \ + std::is_same_v, \ + "tl::mma_sync requires matching accumulator/output regs"); \ + static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ + const BRegType *b, const CRegType *c) { \ + call_fma(d, a, b, c); \ + } \ + }; + +// FP16 inputs (TN layout: A row-major, B column-major) +TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F16F16F16F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F32F16F16F32_TN) + +// BF16 inputs +TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F32BF16BF16F32_TN) + +// INT8 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32S8S8S32_TN) +TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32U8U8S32_TN) + +// INT4 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32S4S4S32_TN) +TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32U4U4S32_TN) + +// FP8 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN) 
+TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN) + +// TF32 inputs (FP32 math on Tensor Cores) +// Support both k=4 and k=8 variants on SM80 +TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4, + false, true, false, + cute::SM80_16x8x4_F32TF32TF32F32_TN) +TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, + false, true, false, + cute::SM80_16x8x8_F32TF32TF32F32_TN) + +// FP64 inputs (DMMA: m8n8k4, TN layout) +TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true, + false, cute::SM80_8x8x4_F64F64F64F64_TN) + +#undef TL_DEFINE_MMA_DISPATCHER + +} // namespace detail + +template +TL_DEVICE void mma_sync( + typename detail::MmaDispatcher::CRegType *c, + const typename detail::MmaDispatcher::ARegType *a, + const typename detail::MmaDispatcher::BRegType *b) { + using Dispatcher = detail::MmaDispatcher; + static_assert(!std::is_void_v, + "tl::mma_sync: unsupported configuration"); + Dispatcher::exec(c, a, b, c); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/instruction/mma_sm70.h b/tilelang/original/src/tl_templates/cuda/instruction/mma_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..7a44b92124d7d8d454c8bbaa1693fdf1ef1f64d1 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/instruction/mma_sm70.h @@ -0,0 +1,355 @@ +#pragma once + +#include "../common.h" + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +namespace detail { + +// SM70 MMA Instruction Traits and Implementations +// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32 +// accumulation + +// Base template for SM70 MMA implementation +template +struct MmaSm70Impl { + // Default: unsupported configuration + static constexpr bool kSupported = false; + + static TL_DEVICE void exec(void *, const void *, const void *, const void *) { + static_assert(always_false_v>, + "tl::mma_sync_sm70: unsupported configuration"); + } +}; + +// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true) +template <> +struct MmaSm70Impl { + using DRegisters = unsigned[4]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = unsigned[4]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2, + unsigned &d3, unsigned a0, unsigned a1, unsigned b0, + unsigned b1, unsigned c0, unsigned c1, unsigned c2, + unsigned c3) { + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3)); + } +}; + +// FP16 inputs, FP16 accumulation - col.row (TransA=true, TransB=false) +template <> +struct MmaSm70Impl { + using DRegisters = unsigned[4]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = unsigned[4]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(unsigned 
&d0, unsigned &d1, unsigned &d2, + unsigned &d3, unsigned a0, unsigned a1, unsigned b0, + unsigned b1, unsigned c0, unsigned c1, unsigned c2, + unsigned c3) { + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3)); + } +}; + +// FP16 inputs, FP16 accumulation - row.col (TransA=false, TransB=true) +template <> +struct MmaSm70Impl { + using DRegisters = unsigned[4]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = unsigned[4]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2, + unsigned &d3, unsigned a0, unsigned a1, unsigned b0, + unsigned b1, unsigned c0, unsigned c1, unsigned c2, + unsigned c3) { + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3)); + } +}; + +// FP16 inputs, FP16 accumulation - row.row (TransA=false, TransB=false) +template <> +struct MmaSm70Impl { + using DRegisters = unsigned[4]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = unsigned[4]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2, + unsigned &d3, unsigned a0, unsigned a1, unsigned b0, + unsigned b1, unsigned c0, unsigned c1, unsigned c2, + unsigned c3) { + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3)); + } +}; + +// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true) +template <> +struct MmaSm70Impl { + using DRegisters = float[8]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = float[8]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3, + float &d4, float &d5, float &d6, float &d7, + unsigned a0, unsigned a1, unsigned b0, unsigned b1, + float c0, float c1, float c2, float c3, float c4, + float c5, float c6, float c7) { + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5), + "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7)); + } +}; + +// FP16 inputs, FP32 accumulation - col.row (TransA=true, TransB=false) +template <> +struct MmaSm70Impl { + using DRegisters = float[8]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = float[8]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3, + float &d4, float &d5, float &d6, float &d7, + unsigned a0, unsigned a1, unsigned b0, unsigned b1, + float c0, float c1, float c2, float c3, float c4, + float c5, float c6, float c7) { + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), 
"=f"(d4), "=f"(d5), + "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7)); + } +}; + +// FP16 inputs, FP32 accumulation - row.col (TransA=false, TransB=true) +template <> +struct MmaSm70Impl { + using DRegisters = float[8]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = float[8]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3, + float &d4, float &d5, float &d6, float &d7, + unsigned a0, unsigned a1, unsigned b0, unsigned b1, + float c0, float c1, float c2, float c3, float c4, + float c5, float c6, float c7) { + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5), + "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7)); + } +}; + +// FP16 inputs, FP32 accumulation - row.row (TransA=false, TransB=false) +template <> +struct MmaSm70Impl { + using DRegisters = float[8]; + using ARegisters = unsigned[2]; + using BRegisters = unsigned[2]; + using CRegisters = float[8]; + + static constexpr bool kSupported = true; + + static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3, + float &d4, float &d5, float &d6, float &d7, + unsigned a0, unsigned a1, unsigned b0, unsigned b1, + float c0, float c1, float c2, float c3, float c4, + float c5, float c6, float c7) { + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5), + "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7)); + } +}; + +// Helper to extract register types +template struct MmaSm70ImplTraits { + using DReg = std::remove_extent_t; + using AReg = std::remove_extent_t; + using BReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + + static constexpr int kDRegs = std::extent_v; + static constexpr int kARegs = std::extent_v; + static constexpr int kBRegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; +}; + +// Dispatcher for SM70 MMA operations +template +struct MmaSm70Dispatcher { + using CRegType = void; + using ARegType = void; + using BRegType = void; + + static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, + const CRegType *) { + static_assert(always_false_v>, + "tl::mma_sync_sm70: unsupported configuration. 
" + "SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 " + "accumulation."); + } +}; + +// Helper to call fma with unpacked register arrays +template +TL_DEVICE void +call_fma_impl_sm70(typename MmaSm70ImplTraits::DReg *d, + const typename MmaSm70ImplTraits::AReg *a, + const typename MmaSm70ImplTraits::BReg *b, + const typename MmaSm70ImplTraits::CReg *c, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence) { + Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...); +} + +template +TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits::DReg *d, + const typename MmaSm70ImplTraits::AReg *a, + const typename MmaSm70ImplTraits::BReg *b, + const typename MmaSm70ImplTraits::CReg *c) { + call_fma_impl_sm70( + d, a, b, c, std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}); +} + +// Define dispatchers for all supported SM70 configurations +// Note: m8n8k4 instruction computes m16n16k4 at warp level +#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, \ + TransAValue, TransBValue) \ + template <> \ + struct MmaSm70Dispatcher { \ + using Impl = MmaSm70Impl; \ + using Traits = MmaSm70ImplTraits; \ + using CRegType = typename Traits::DReg; \ + using ARegType = typename Traits::AReg; \ + using BRegType = typename Traits::BReg; \ + static_assert( \ + std::is_same_v, \ + "tl::mma_sync_sm70 requires matching accumulator/output regs"); \ + static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ + const BRegType *b, const CRegType *c) { \ + call_fma_sm70(d, a, b, c); \ + } \ + }; + +// FP16 inputs with FP16 accumulation (all layout combinations) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false) + +// FP16 inputs with FP32 accumulation (all layout combinations) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true) +TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false) + +#undef TL_DEFINE_MMA_SM70_DISPATCHER + +} // namespace detail + +/// SM70 MMA synchronous instruction wrapper +/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs +/// and FP16/FP32 accumulation +/// +/// @tparam AType Input A data type (kFloat16) +/// @tparam BType Input B data type (kFloat16) +/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32) +/// @tparam M Matrix M dimension (16) +/// @tparam N Matrix N dimension (16) +/// @tparam K Matrix K dimension (4) +/// @tparam TransA Whether A is transposed (false=row-major, true=col-major) +/// @tparam TransB Whether B is transposed (false=row-major, true=col-major) +template +TL_DEVICE void mma_sync_sm70( + typename detail::MmaSm70Dispatcher::CRegType *c, + const typename detail::MmaSm70Dispatcher::ARegType *a, + const typename detail::MmaSm70Dispatcher::BRegType *b) { + using Dispatcher = + detail::MmaSm70Dispatcher; + static_assert(!std::is_void_v, + "tl::mma_sync_sm70: unsupported configuration. 
" + "SM70 only supports m16n16k4 with FP16 inputs."); + Dispatcher::exec(c, a, b, c); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/instruction/tcgen05mma.h b/tilelang/original/src/tl_templates/cuda/instruction/tcgen05mma.h new file mode 100644 index 0000000000000000000000000000000000000000..9772d6438291b74bb1d62ab660815735b3b9268b --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -0,0 +1,337 @@ +#pragma once + +#include "../common.h" +#include + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +// Generic declaration: unsupported by default +template +TL_DEVICE void +tcgen05mma_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/, + uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/, + uint32_t const & /*desc_val*/, int const & /*mask0*/, + int const & /*mask1*/, int const & /*mask2*/, + int const & /*mask3*/) { + static_assert( + always_false_v(C_type)>>, + "tl::tcgen05mma_ss: unsupported accumulator type"); +} + +// TS variants: A from TMEM, B from SMEM (desc) +// Generic declaration: unsupported by default +template +TL_DEVICE void +tcgen05mma_ts(uint32_t const & /*tmem_a*/, uint64_t const & /*desc_b*/, + uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/, + uint32_t const & /*desc_val*/, int const & /*mask0*/, + int const & /*mask1*/, int const & /*mask2*/, + int const & /*mask3*/) { + static_assert( + always_false_v(C_type)>>, + "tl::tcgen05mma_ts: unsupported accumulator type"); +} + +// F16/BF16 instruction kind (maps to kind::f16) +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// BF16 maps to the same f16-kind instruction +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ts(tmem_a, desc_b, tmem_c, scalec, desc_val, + mask0, mask1, mask2, mask3); +} + +// TF32 instruction kind +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// INT8 instruction kind +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm 
volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// FP8 family instruction kind (maps to f8f6f4) +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, " + "{%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ts(tmem_a, desc_b, tmem_c, scalec, + desc_val, mask0, mask1, mask2, mask3); +} + +// F16/BF16 instruction kind (maps to kind::f16) +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + // idescE upper 32 bits carry the instruction descriptor; lower 32 ignored for + // SS Load TMEM base from shared memory slot handled by caller + + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// BF16 maps to the same f16-kind instruction +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ss(desc_a, desc_b, tmem_c, scalec, desc_val, + mask0, mask1, mask2, mask3); +} + +// TF32 instruction kind +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// INT8 instruction kind +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, " + "%7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), 
"l"(desc_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +// FP8 family instruction kind (maps to f8f6f4) +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, " + "%6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), + "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)); + } +} + +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ss(desc_a, desc_b, tmem_c, scalec, + desc_val, mask0, mask1, mask2, mask3); +} + +// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx +// Generic declaration falls back to static assert +template +TL_DEVICE void +tcgen05mma_ws_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/, + uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/, + uint32_t const & /*desc_val*/, int const & /*mask0*/, + int const & /*mask1*/, int const & /*mask2*/, + int const & /*mask3*/) { + static_assert( + always_false_v(C_type)>>, + "tl::tcgen05mma_ws_ss: unsupported accumulator type"); +} + +// F16/BF16 ws +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec)); + } +} + +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ws_ss(desc_a, desc_b, tmem_c, scalec, desc_val, + mask0, mask1, mask2, mask3); +} + +// TF32 ws +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec)); + } +} + +// INT8 ws +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec)); + } +} + 
+// FP8 ws (maps to f8f6f4) +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec)); + } +} + +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ws_ss( + desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/instruction/wgmma.h b/tilelang/original/src/tl_templates/cuda/instruction/wgmma.h new file mode 100644 index 0000000000000000000000000000000000000000..3af2d79fe0c49941d2f993e1faa0e3b31fb30d07 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/instruction/wgmma.h @@ -0,0 +1,473 @@ +#pragma once + +#include "../common.h" +#include +#include + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +namespace detail { + +template struct MajorValue { + static constexpr auto value = + IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; +}; + +template struct ScaleInValue { + static_assert(Scale == 1 || Scale == -1, + "tl::wgmma requires scale factors of +1 or -1."); + static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One + : cute::SM90::GMMA::ScaleIn::Neg; +}; + +template +inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1); + +template struct CallWgmmaSS { + using CReg = std::remove_extent_t; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(CReg) == sizeof(uint32_t), + "tl::wgmma_ss expects 32-bit accumulator registers."); + + template + TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c, + cute::SM90::GMMA::ScaleOut scale, + std::index_sequence) { + Impl::fma(desc_a, desc_b, c[Idx]..., scale); + } + + TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw, + bool scale_out) { + auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto c = reinterpret_cast(c_raw); + Run(desc_a, desc_b, c, scale, std::make_index_sequence{}); + } +}; + +template struct CallWgmmaRS { + using AReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + static constexpr int kARegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(AReg) == sizeof(uint32_t), + "tl::wgmma_rs expects 32-bit register operands for A."); + static_assert(sizeof(CReg) == sizeof(uint32_t) || + sizeof(CReg) == sizeof(float), + "tl::wgmma_rs expects 32-bit accumulator registers."); + + template + TL_DEVICE static void + Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale, + std::index_sequence, std::index_sequence) { + Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale); + } + + TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b, + uint32_t *c_raw, bool scale_out) { + auto scale = scale_out ? 
cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto a = reinterpret_cast(a_raw); + auto c = reinterpret_cast(c_raw); + Run(a, desc_b, c, scale, std::make_index_sequence{}, + std::make_index_sequence{}); + } +}; + +} // namespace detail + +template +struct WgmmaSSImpl { + static_assert(detail::IsValidScale, "tl::wgmma_ss: invalid scaleA"); + static_assert(detail::IsValidScale, "tl::wgmma_ss: invalid scaleB"); + TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) { + static_assert(always_false_v>, + "tl::wgmma_ss: unsupported configuration"); + } +}; + +template +struct WgmmaRSImpl { + static_assert(detail::IsValidScale, "tl::wgmma_rs: invalid scaleA"); + static_assert(detail::IsValidScale, "tl::wgmma_rs: invalid scaleB"); + TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) { + static_assert(always_false_v>, + "tl::wgmma_rs: unsupported configuration"); + } +}; + +#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + 
static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; + +#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(40) \ + OP(48) \ + OP(56) \ + OP(64) \ + OP(72) \ + OP(80) \ + OP(88) \ + OP(96) \ + OP(104) \ + OP(112) \ + OP(120) \ + OP(128) \ + OP(136) \ + OP(144) \ + OP(152) \ + OP(160) \ + OP(168) \ + OP(176) \ + OP(184) \ + OP(192) \ + OP(200) \ + OP(208) \ + OP(216) \ + OP(224) \ + OP(232) \ + OP(240) \ + OP(248) \ + OP(256) + +#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(48) \ + OP(64) \ + OP(80) \ + OP(96) \ + OP(112) \ + OP(128) \ + OP(144) \ + OP(160) \ + OP(176) \ + OP(192) \ + OP(208) \ + OP(224) \ + OP(240) \ + OP(256) + +#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \ + MMA_64x##N##x16_F16F16F16_SS) +#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32F16F16_SS) +#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32BF16BF16_SS) + +#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \ + MMA_64x##N##x8_F32TF32TF32_SS_TN) + +#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8S8_SS_TN) +#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8U8_SS_TN) +#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8S8_SS_TN) +#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8U8_SS_TN) + +#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E5M2_SS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E5M2_SS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \ + 
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E5M2_SS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E5M2_SS_TN) + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN); + +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN); + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN); + +#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \ + MMA_64x##N##x16_F16F16F16_RS) +#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32F16F16_RS) +#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32BF16BF16_RS) + +#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \ + MMA_64x##N##x8_F32TF32TF32_RS_TN) + +#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8S8_RS_TN) +#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8U8_RS_TN) +#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8S8_RS_TN) +#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8U8_RS_TN) + +#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E5M2_RS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E5M2_RS_TN) +#define 
TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E5M2_RS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E5M2_RS_TN) + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN); + +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN); + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN); + +#undef TL_WGMMA_DEFINE_F16_F16_F16_SS +#undef TL_WGMMA_DEFINE_F16_F16_F32_SS +#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS +#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN +#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN +#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN +#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN +#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F16_F16_F16_RS +#undef TL_WGMMA_DEFINE_F16_F16_F32_RS +#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS +#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN +#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN +#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN +#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN +#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN +#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8 +#undef TL_WGMMA_FOREACH_N_INT32_MUL8 +#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE +#undef TL_WGMMA_DEFINE_SS_GENERAL +#undef TL_WGMMA_DEFINE_SS_TN +#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE +#undef TL_WGMMA_DEFINE_RS_GENERAL +#undef TL_WGMMA_DEFINE_RS_TN + +template +TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + WgmmaSSImpl::execute(desc_a, desc_b, c, scale_out); +} + +template +TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c, 
+ bool scale_out) { + WgmmaRSImpl::execute(a, desc_b, c, scale_out); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/intrin.h b/tilelang/original/src/tl_templates/cuda/intrin.h new file mode 100644 index 0000000000000000000000000000000000000000..0d5b5639deec8ee30ee5d337319248ca606d6ffd --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/intrin.h @@ -0,0 +1,133 @@ +#pragma once + +#include "common.h" +#include "cutlass/cutlass.h" + +#if __CUDA_ARCH_LIST__ >= 900 +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/mma_sm90_gmma.hpp" +#endif + +namespace tl { + +namespace detail { + +// Provide architecture-specific defaults so callers may omit arguments. +TL_DEVICE constexpr int default_warp_size() { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__) + return 64; +#else + return 32; +#endif +} + +TL_DEVICE constexpr int default_warps_per_group() { return 4; } + +TL_DEVICE int linear_thread_idx_in_block() { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); +#else + return 0; +#endif +} + +} // namespace detail + +TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() % warp_size; +} + +TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() / warp_size; +} + +TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() / warp_size; +} + +TL_DEVICE int +get_warp_group_idx(int warp_size = detail::default_warp_size(), + int warps_per_group = detail::default_warps_per_group()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + warps_per_group = + warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group(); + int threads_per_group = warp_size * warps_per_group; + threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size; + return detail::linear_thread_idx_in_block() / threads_per_group; +} + +#if __CUDA_ARCH_LIST__ >= 900 +TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); } +TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); } + +template TL_DEVICE void warpgroup_wait() { + cute::warpgroup_wait(); +} + +TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) { +#pragma unroll + for (int i = 0; i < count; ++i) { + cute::warpgroup_fence_operand(regs[i]); + } +} + +TL_DEVICE void warpgroup_fence_operand(float *regs, int count) { +#pragma unroll + for (int i = 0; i < count; ++i) { + cute::warpgroup_fence_operand(regs[i]); + } +} + +// Template parameter: +// thread_extent: the logical size (in number of threads) of each "group" +// within which we want to elect exactly ONE representative +// thread. +template TL_DEVICE bool tl_shuffle_elect() { + + // Special case: thread_extent == 0 means "elect exactly one thread + // in the entire thread block", i.e., the leader of the first warp of the + // block. + if constexpr (thread_extent == 0) { + // cutlass::canonical_warp_idx_sync(): + // Returns the warp ID within the thread block in a "canonical" way + // (0 for the first warp, 1 for the second, ...). 
+ // cute::elect_one_sync(): + // Elect exactly one lane in the warp to return true (typically lane 0), + // other lanes return false. + // The condition ensures that: + // (1) We are in warp 0 of the block. + // (2) We are the elected lane in this warp. + return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync(); + } + + // General case: thread_extent != 0 + // (threadIdx.x / 32) is the warp index in the block. + // (thread_extent / 32) is the number of warps in one group of size + // thread_extent. We take warp_id % num_warps_in_group to get the warp's index + // within the group. + // __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all + // lanes in the warp. Here it broadcasts the group-local warp index from lane + // 0. Comparing to 0 selects only the group's warp 0. + return __shfl_sync(0xffffffff, // full warp mask + (threadIdx.x / 32) % + (thread_extent / 32), // warp index within group + 0 // take the value from lane 0 + ) == 0 && + // Within that group leader warp, elect exactly one lane (typically + // lane 0) to be the single representative for the group. + cute::elect_one_sync(); +} + +template TL_DEVICE void warpgroup_reg_alloc() { + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +} + +template TL_DEVICE void warpgroup_reg_dealloc() { + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +} +#endif + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/ldsm.h b/tilelang/original/src/tl_templates/cuda/ldsm.h new file mode 100644 index 0000000000000000000000000000000000000000..4d6af8a0998723b8ba1155223818bfa692d012da --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/ldsm.h @@ -0,0 +1,121 @@ +#pragma once + +#include "common.h" + +namespace tl { + +TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr, + void *const 
local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr, + const int32_t &value0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"( + smem_int_ptr), + "r"(value0)); +} + +TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr, + const int32_t &value0, const int32_t &value1) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"( + smem_int_ptr), + "r"(value0), "r"(value1)); +} + +TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr, + const int32_t &value0, const int32_t &value1, + const int32_t &value2, const int32_t &value3) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: + "r"(smem_int_ptr), + "r"(value0), "r"(value1), "r"(value2), "r"(value3)); +} + +TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr, + const int32_t &value0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"( + smem_int_ptr), + "r"(value0)); +} + +TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr, + const int32_t &value0, + const int32_t &value1) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"( + smem_int_ptr), + "r"(value0), "r"(value1)); +} + +TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr, + const int32_t &value0, + const int32_t &value1, + const int32_t &value2, + const int32_t &value3) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, " + "%3, %4};\n" ::"r"(smem_int_ptr), + "r"(value0), "r"(value1), "r"(value2), "r"(value3)); +} + +} // namespace tl \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/cuda/nvrtc_std.h b/tilelang/original/src/tl_templates/cuda/nvrtc_std.h new file mode 100644 index 0000000000000000000000000000000000000000..34cd58bb29994dbb1a50e7e9a1539bbfcc290380 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/nvrtc_std.h @@ -0,0 +1,176 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef __CUDACC_RTC__ + +// Disable problematic CUDA standard library headers in NVRTC environment +// Vector types (float4, uchar, etc.) 
are built-in to NVRTC and don't need these
+// headers
+#define _LIBCUDACXX___TUPLE_VECTOR_TYPES_H // Prevent vector_types.h inclusion
+
+using int8_t = signed char;
+using uint8_t = unsigned char;
+using int16_t = signed short;
+using uint16_t = unsigned short;
+using int32_t = signed int;
+using uint32_t = unsigned int;
+using int64_t = signed long long;
+using uint64_t = unsigned long long;
+using cuuint64_t = unsigned long long;
+
+#ifndef CU_TENSOR_MAP_NUM_QWORDS
+#define CU_TENSOR_MAP_NUM_QWORDS 16
+
+struct CUtensorMap_st {
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+  alignas(64)
+#elif __STDC_VERSION__ >= 201112L
+  _Alignas(64)
+#endif
+      cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
+};
+
+using CUtensorMap = CUtensorMap_st;
+#endif
+
+namespace std {
+
+template <typename T, T v> struct integral_constant {
+  static constexpr T value = v;
+
+  using value_type = T;
+  using type = integral_constant;
+
+  __device__ constexpr operator value_type() const noexcept { return value; }
+
+  __device__ constexpr value_type operator()() const noexcept { return value; }
+};
+
+using false_type = integral_constant<bool, false>;
+using true_type = integral_constant<bool, true>;
+
+template <typename T, typename U> struct is_same : false_type {};
+
+template <typename T> struct is_same<T, T> : true_type {};
+
+template <typename T, typename U>
+inline constexpr bool is_same_v = is_same<T, U>::value;
+
+template <typename T> struct is_void : false_type {};
+
+template <> struct is_void<void> : true_type {};
+template <> struct is_void<const void> : true_type {};
+template <> struct is_void<volatile void> : true_type {};
+template <> struct is_void<const volatile void> : true_type {};
+
+template <typename T> inline constexpr bool is_void_v = is_void<T>::value;
+
+template <typename T> struct is_pointer : false_type {};
+
+template <typename T> struct is_pointer<T *> : true_type {};
+template <typename T> struct is_pointer<T *const> : true_type {};
+template <typename T> struct is_pointer<T *volatile> : true_type {};
+template <typename T> struct is_pointer<T *const volatile> : true_type {};
+
+template <typename T> inline constexpr bool is_pointer_v = is_pointer<T>::value;
+
+namespace index_sequence_impl {
+
+// Based on https://stackoverflow.com/a/32223343/11717224
+template <size_t... Ints> struct index_sequence {
+  using type = index_sequence;
+  using value_type = size_t;
+  static constexpr size_t size() noexcept { return sizeof...(Ints); }
+};
+
+template <typename Seq1, typename Seq2> struct _merge_and_renumber;
+
+template <size_t... I1, size_t... I2>
+struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
+    : index_sequence<I1..., (sizeof...(I1) + I2)...> {};
+
+template <size_t N>
+struct make_index_sequence
+    : _merge_and_renumber<typename make_index_sequence<N / 2>::type,
+                          typename make_index_sequence<N - N / 2>::type> {};
+
+template <> struct make_index_sequence<0> : index_sequence<> {};
+template <> struct make_index_sequence<1> : index_sequence<0> {};
+
+} // namespace index_sequence_impl
+
+template <size_t... Ints>
+using index_sequence = index_sequence_impl::index_sequence<Ints...>;
+
+template <size_t N>
+using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
+
+template <typename T> constexpr T min(T a, T b) { return a < b ? a : b; }
+
+template <typename T> constexpr T max(T a, T b) { return a > b ?
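+// The std shim in this namespace exists because NVRTC compiles device code
+// without the host standard library: only the pieces of <type_traits> and
+// <utility> that the device templates actually use are re-declared here
+// (integral_constant, is_same, is_void, is_pointer, index_sequence, min/max
+// above; conditional, enable_if, remove_extent, extent below), all guarded by
+// __CUDACC_RTC__ so regular NVCC builds keep the real headers. A minimal
+// illustrative use of these traits (not taken from the original source):
+//   static_assert(std::is_same_v<std::conditional_t<true, int, float>, int>);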
a : b; } + +template struct conditional { + using type = T; +}; + +template struct conditional { + using type = F; +}; + +template +using conditional_t = typename conditional::type; + +template struct enable_if {}; + +template struct enable_if { + using type = T; +}; + +template struct remove_extent { + using type = T; +}; + +template struct remove_extent { + using type = T; +}; + +template struct remove_extent { + using type = T; +}; + +template using remove_extent_t = typename remove_extent::type; + +template +struct extent : integral_constant {}; + +template struct extent : integral_constant {}; + +template struct extent : extent {}; + +template +struct extent : integral_constant {}; + +template +struct extent : extent {}; + +template +inline constexpr size_t extent_v = extent::value; +} // namespace std + +#endif // __CUDACC_RTC__ diff --git a/tilelang/original/src/tl_templates/cuda/reduce.h b/tilelang/original/src/tl_templates/cuda/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..4582426493b5cbf1bc22605b7d571549b0e692fc --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/reduce.h @@ -0,0 +1,284 @@ +#pragma once + +#include "common.h" + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +// Select a wider accumulator type for improved numerical accuracy. +// Default: accumulate in the same type. Specialize FP16/BF16 to float. +template struct AccType { + using type = T; +}; +template <> struct AccType { + using type = float; +}; +template <> struct AccType { + using type = float; +}; + +struct SumOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x + y; + } +}; + +struct MaxOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_max(x, y); + } +}; + +struct MinOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_min(x, y); + } +}; + +struct BitAndOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x & y; + } +}; + +struct BitOrOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x | y; + } +}; + +struct BitXorOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x ^ y; + } +}; + +template +struct AllReduce { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32 or + threads == 16 or threads == 8 or threads == 4 or threads == 2); + static_assert(threads % scale == 0); + template static TL_DEVICE T run(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + if constexpr (offset >= 32) { + __syncthreads(); + red_buf[threadIdx.x - thread_offset] = x; + __syncthreads(); + x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); + } else { + x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset)); + } + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run( + x, red_buf); + } + } + + template + static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + if constexpr (offset >= 32) { + asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads)); + red_buf[threadIdx.x - thread_offset] = x; + // TODO(lei): maybe we can merge the two bar.sync into one? 
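+      // The two named barriers guard different hazards: the first "bar.sync"
+      // (id 1) above makes sure every participating thread has finished
+      // reading red_buf from the previous reduction round before it is
+      // overwritten (write-after-read), while the second (id 2) below makes
+      // the fresh stores visible to the XOR partner before it loads them
+      // (read-after-write). Unlike __syncthreads() in run(), "bar.sync id, n"
+      // synchronizes only the all_threads threads that name the same barrier
+      // id, so warps outside this group (e.g. producer warps) are not stalled.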
+ asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); + x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); + } else { + x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset)); + } + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run_hopper(x, red_buf); + } + } +}; + +template struct CumSum1D { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32); + template + static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, + int N) { + if (N <= 0) + return; + + constexpr unsigned MASK = 0xffffffff; + const int tid = threadIdx.x; + const int lane = tid % SEG; + + if (tid >= SEG) + return; + + T carry = (T)0; + + if (reverse) { + const int num_segments = (N + SEG - 1) / SEG; + for (int seg = num_segments - 1; seg >= 0; --seg) { + const int idx = seg * SEG + lane; + T val = (idx < N) ? src[idx] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)tl::shfl_down_sync(MASK, val, off); + if (lane < SEG - off) + val += n; + } + + val += carry; + + if (idx < N) + dst[idx] = val; + + T segSum = (T)__shfl_sync(MASK, val, 0); + if (lane == 0) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, 0); + } + } else { + const int num_segments = (N + SEG - 1) / SEG; + for (int seg = 0; seg < num_segments; ++seg) { + const int idx = seg * SEG + lane; + T val = (idx < N) ? src[idx] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_up_sync(MASK, val, off); + if (lane >= off) + val += n; + } + + val += carry; + + if (idx < N) + dst[idx] = val; + + T segSum = (T)__shfl_sync(MASK, val, SEG - 1); + if (lane == SEG - 1) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, SEG - 1); + } + } + } +}; + +template struct CumSum2D { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32); + template + static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H, + int W) { + + constexpr int TILE_H = threads / SEG; + constexpr unsigned MASK = 0xffffffff; + const int num_blocks = (H + TILE_H - 1) / TILE_H; + const int tid = threadIdx.x; + const int lane = tid % 32; + const int row = tid / 32; + + for (int b = 0; b < num_blocks; ++b) { + const int gRow = b * TILE_H + row; + if (gRow >= H) + return; + + T carry = (T)0; + + if (reverse) { + // Start from the last segment for reverse mode + for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = tl::shfl_down_sync(MASK, val, off); + if (lane < SEG - off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = tl::shfl_sync(MASK, val, 0); + if (lane == 0) + carry = segSum; + carry = tl::shfl_sync(MASK, carry, 0); + } + } else { + for (int seg = 0; seg * SEG < W; ++seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? 
src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = tl::shfl_up_sync(MASK, val, off); + if (lane >= off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = tl::shfl_sync(MASK, val, SEG - 1); + if (lane == SEG - 1) + carry = segSum; + carry = tl::shfl_sync(MASK, carry, SEG - 1); + } + } + } + } +}; + +template +TL_DEVICE T warp_reduce(T value, ReduceOp op) { + constexpr uint32_t mask = 0xffffffff; + value = op(value, __shfl_xor_sync(mask, value, 16)); + value = op(value, __shfl_xor_sync(mask, value, 8)); + value = op(value, __shfl_xor_sync(mask, value, 4)); + value = op(value, __shfl_xor_sync(mask, value, 2)); + value = op(value, __shfl_xor_sync(mask, value, 1)); + return value; +} + +template TL_DEVICE T warp_reduce_sum(T value) { + return warp_reduce(value, SumOp()); +} + +template TL_DEVICE T warp_reduce_max(T value) { + return warp_reduce(value, MaxOp()); +} + +template TL_DEVICE T warp_reduce_min(T value) { + return warp_reduce(value, MinOp()); +} + +template TL_DEVICE T warp_reduce_bitand(T value) { + return warp_reduce(value, BitAndOp()); +} + +template TL_DEVICE T warp_reduce_bitor(T value) { + return warp_reduce(value, BitOrOp()); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/tcgen_05.h b/tilelang/original/src/tl_templates/cuda/tcgen_05.h new file mode 100644 index 0000000000000000000000000000000000000000..e40907e3405cad52f1b50b9ec193a5df2afd6c16 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/tcgen_05.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" +#include + +namespace tl { + +TL_DEVICE void tmem_allocate(void *dst_ptr, int num_columns) { + uint32_t dst_intptr = smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); +} + +TL_DEVICE void tmem_deallocate(uint32_t *tmem_ptr, int num_columns) { + asm volatile("{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(*tmem_ptr), "r"(num_columns)); +} + +inline void __device__ fence_view_async_tmem_load() { + asm volatile("tcgen05.wait::ld.sync.aligned; " ::); +} + +inline void __device__ fence_view_async_tmem_store() { + asm volatile("tcgen05.wait::st.sync.aligned; " ::); +} + +template +inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a, + uint64_t const desc_b, + uint32_t const tmem_c, + uint32_t const idesc, + uint32_t const addC = 1) { + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be " + "64 or 128 for 1 CTA cluster MMA."); + static_assert( + (M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, " + "%7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(idesc), "r"(addC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); +} + +// Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier +TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) { + uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr); + if (cute::elect_one_sync()) { 
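+    // Only the single elected lane issues the commit, so the mbarrier receives
+    // exactly one arrive for the warp. tcgen05.commit signals the mbarrier at
+    // bar_intptr once the previously issued tcgen05.mma operations have
+    // completed; waiting on that mbarrier is how consumers learn that the MMA
+    // results in tensor memory are ready to be read.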
+ asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::" + "cluster.b64 [%0];" + : + : "r"(bar_intptr)); + } +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/tcgen_05_ld.h b/tilelang/original/src/tl_templates/cuda/tcgen_05_ld.h new file mode 100644 index 0000000000000000000000000000000000000000..9e5e34206a8139fd17616f10cf6d94272916b399 --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/tcgen_05_ld.h @@ -0,0 +1,1380 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" + +namespace tl { + +// 32 data path lanes, 32-bit pattern, repeated N times +template class tmem_ld_32dp32bNx; + +template <> class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, 
%48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), 
"=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), 
"=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + 
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 64-bit pattern, repeated N times +template class tmem_ld_16dp64bNx; +template <> class tmem_ld_16dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), 
"=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, 
%122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm 
volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + 
"tcgen05.ld.sync.aligned.16x64b.pack::16b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 128-bit pattern, repeated N times +template class tmem_ld_16dp128bNx; +template <> class tmem_ld_16dp128bNx { +public: 
+ template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + 
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + 
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), 
"=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), 
"=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +template class tmem_ld_16dp256bNx; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " 
+ "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + 
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } 
else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), 
"=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 32 data path lanes, 64-bit pattern, repeated N times +// (conducted with 2x16dp64bNx) +template class tmem_ld_32dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + } +}; + +// 32 data path lanes, 128-bit pattern, repeated N times +template class tmem_ld_32dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + } +}; + +// 32 data path lanes, 256-bit pattern, repeated N times +template class tmem_ld_32dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + } +}; + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/cuda/threadblock_swizzle.h b/tilelang/original/src/tl_templates/cuda/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..60fa0ad1f05f8907c2fc6d7663c760820a319d4e --- /dev/null +++ b/tilelang/original/src/tl_templates/cuda/threadblock_swizzle.h @@ -0,0 +1,43 @@ +#pragma once + +#include "common.h" + +namespace tl { + +template TL_DEVICE dim3 rasterization2DRow() { + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.x; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? 
panel_width + : (grid_size - panel_idx * panel_size) / gridDim.x; + const unsigned int col_idx = (panel_idx & 1) + ? gridDim.x - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +template TL_DEVICE dim3 rasterization2DColumn() { + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.y; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.y; + const unsigned int row_idx = (panel_idx & 1) + ? gridDim.y - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/dcu_hip/common.h b/tilelang/original/src/tl_templates/dcu_hip/common.h new file mode 100644 index 0000000000000000000000000000000000000000..0aae7e9d3a470fecda91c9443bc8d2ca8fae1089 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/common.h @@ -0,0 +1,140 @@ +#pragma once + +#include "core.hpp" +#include +#include +#include +// #include + +#define HIPRT_INF_F __int_as_float(0x7f800000) +#define HIPRT_NEGINF_F __int_as_float(0xff800000) +#define HIPRT_NAN_F __int_as_float(0x7fffffff) +#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001) +#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff) +#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000) +#define HIPRT_ZERO_F 0.0f +#define HIPRT_ONE_F 1.0f + +/* double precision constants */ +#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000) +#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000) + +#define uint unsigned int +#define uchar unsigned char +#define ushort unsigned short + +#define TL_DEVICE __forceinline__ __device__ +#define TL_DEVICE_NOINLINE __noinline__ __device__ + +#define TILELANG_CHECK(stmt) \ + do { \ + hipError_t __err = (stmt); \ + if (__err != hipSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \ + __LINE__, hipGetErrorName(__err), hipGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define TILELANG_CHECK_LAST_ERROR(kernel_name) \ + do { \ + hipError_t __err = hipGetLastError(); \ + if (__err != hipSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "kernel_name: %s - %s", \ + hipGetErrorName(__err), hipGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define half _Float16 +#define __float2half_rn(x) half(x) + +#define hpow __ocml_pown_f16 +#define hsqrt __ocml_sqrt_f16 + +using float16_t = _Float16; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; +using float16x8 = + __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t; +using float16x16 = + __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t; + +using half_t = float16_t; + +using bfloat16_t = __hip_bfloat16; + +struct bfloat16x2 { + bfloat16_t x, y; +}; + +struct bfloat16x4 { + bfloat16_t data[4]; +}; + +struct bfloat16x8 { + bfloat16_t data[8]; +}; + +struct bfloat16x16 { + 
bfloat16_t data[16]; +}; + +typedef + __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec; + +using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + +using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; + +// Pack two half_t values. +TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. +TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +template struct is_half_type : std::false_type {}; + +template <> struct is_half_type<__half> : std::true_type {}; + +template <> struct is_half_type : std::true_type {}; + +template +inline constexpr bool is_half_v = is_half_type>::value; + +template +TL_DEVICE void AtomicAdd(T1 *address, T2 val) { + if constexpr (is_half_v) { + __half *addr = reinterpret_cast<__half *>(address); + __half hval = __float2half(static_cast(val)); + atomicAdd(addr, hval); + } else { + atomicAdd(address, static_cast(val)); + } +} + +template TL_DEVICE void AtomicAdd(T1 &ref, T2 val) { + AtomicAdd(&ref, val); +} +template TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) { + return atomicAdd(&ref, static_cast(val)); +} + +template TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) { + atomicAdd(&ref[0], val[0]); + atomicAdd(&ref[1], val[1]); + atomicAdd(&ref[2], val[2]); + atomicAdd(&ref[3], val[3]); +} \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/dcu_hip/copy.h b/tilelang/original/src/tl_templates/dcu_hip/copy.h new file mode 100644 index 0000000000000000000000000000000000000000..3ba334da88599516a9f6b98da2e5888bd92e5b86 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/copy.h @@ -0,0 +1,110 @@ +#pragma once + +#include "common.h" + +using f32 = float; +// using f16 = _Float16; + +using u8 = std::uint8_t; +using u16 = std::uint16_t; +using u32 = std::uint32_t; + +using index_t = u32; + +using ck_tile::int32x4_t; + +struct __attribute__((packed)) buffer_resource { + const void *ptr; + uint32_t range; + uint32_t config; +}; + +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, + uint32_t size = 0xffffffff) { + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; + int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); + return r; +} + +__device__ void init_m0(uint32_t m0_value) { + asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory"); +} + +__device__ void inc_m0(uint32_t m0_inc) { + asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory"); +} + +namespace tl { + +// AMDGPU automatically commit memory fence +TL_DEVICE void cp_async_commit() {} + +// Global Memory only fence +__device__ void async_gld_fence(index_t cnt) { + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// Global Memory and Shared Memory fence +__device__ void async_gld_sld_fence(index_t cnt) { + asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); +} + +__device__ void wave_barrier() { asm volatile("s_barrier" : : : 
"memory"); } + +template TL_DEVICE void cp_async_wait() { + async_gld_fence(N); + // or + // async_gld_sld_fence(N); +} + +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc, + index_t voffset) { + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(smem))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(voffset), "s"(rsrc) + : "memory"); +} + +template +TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) { + if constexpr (N == 16) { + *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr; + } else if constexpr (N == 8) { + *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr; + } else if constexpr (N == 4) { + async_buffer_load_dword_v( + lds_base_ptr, + make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), + threadIdx.x * N /*assume 4 bytes*/); + } +} + +template +TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, + void *global_base_ptr, bool cond) { + if constexpr (N == 16) { + *(uint4 *)lds_base_ptr = + cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0); + } else if constexpr (N == 8) { + *(uint2 *)lds_base_ptr = + cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0); + } else { + if (cond) { + async_buffer_load_dword_v( + lds_base_ptr, + make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), + threadIdx.x * N /*assume 4 bytes*/); + } else { + *(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0); + } + } +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/dcu_hip/core.hpp b/tilelang/original/src/tl_templates/dcu_hip/core.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d718622f53fc1708687353dad70b1647416f8c5 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/core.hpp @@ -0,0 +1,77 @@ +#ifdef __HIPCC__ +#define CK_TILE_HOST inline __host__ +#define CK_TILE_DEVICE inline __device__ +#define CK_TILE_HOST_DEVICE inline __host__ __device__ +#define CK_TILE_DEVICE_EXTERN __device__ +#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ +#else +#define CK_TILE_HOST inline +#define CK_TILE_DEVICE inline +#define CK_TILE_HOST_DEVICE inline +#define CK_TILE_DEVICE_EXTERN +#define CK_TILE_HOST_DEVICE_EXTERN +#endif + +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \ + defined(__gfx9__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx103__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 +#else +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 +#endif + +namespace ck_tile { +using int32x4_t = int32_t __attribute__((ext_vector_type(4))); +template CK_TILE_HOST_DEVICE constexpr T max(T x) { return x; } + +template CK_TILE_HOST constexpr T max(T x, T y) { + return x > y ? x : y; +} + +template CK_TILE_DEVICE constexpr T max(T x, T y) { + return x > y ? x : y; +} + +template <> CK_TILE_DEVICE float max(float x, float y) { + return __builtin_fmaxf(x, y); // can resultin v_max3_f32 +} + +template <> CK_TILE_DEVICE double max(double x, double y) { + return __builtin_fmax(x, y); // maybe still v_max3_f32 +} + +template +CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... 
ys) { + static_assert(sizeof...(Ys) > 0, "not enough arguments"); + return max(x, max(ys...)); +} + +template CK_TILE_HOST_DEVICE constexpr T min(T x) { return x; } + +template CK_TILE_HOST constexpr T min(T x, T y) { + return x < y ? x : y; +} + +template CK_TILE_DEVICE constexpr T min(T x, T y) { + return x < y ? x : y; +} + +template <> CK_TILE_DEVICE float min(float x, float y) { + return __builtin_fminf(x, y); +} + +template <> CK_TILE_DEVICE double min(double x, double y) { + return __builtin_fmin(x, y); +} + +template +CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) { + static_assert(sizeof...(Ys) > 0, "not enough arguments"); + return min(x, min(ys...)); +} +} // namespace ck_tile diff --git a/tilelang/original/src/tl_templates/dcu_hip/debug.h b/tilelang/original/src/tl_templates/dcu_hip/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..7b19d3e943719883a0defd1d5d422e40b9b6998c --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/debug.h @@ -0,0 +1,191 @@ +#pragma once +#include + +// Base template declaration +template __device__ void debug_print_var(const char *msg, T var); + +// Specialization for signed char type +template <> +__device__ void debug_print_var(const char *msg, signed char var) { + const char *safe_msg = msg; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed " + "char value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for unsigned char type +template <> +__device__ void debug_print_var(const char *msg, + unsigned char var) { + const char *safe_msg = msg; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned char value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for int type +template <> __device__ void debug_print_var(const char *msg, int var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " + "value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for unsigned int type +template <> +__device__ void debug_print_var(const char *msg, + unsigned int var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned int value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for float type +template <> __device__ void debug_print_var(const char *msg, float var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " + "value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for double type +template <> +__device__ void debug_print_var(const char *msg, double var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " + "value=%lf\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for bool type +template <> __device__ void 
debug_print_var(const char *msg, bool var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " + "value=%s\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, + var ? "true" : "false"); +} + +// Specialization for short type +template <> __device__ void debug_print_var(const char *msg, short var) { + const char *safe_msg = msg; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=short " + "value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for unsigned short type +template <> +__device__ void debug_print_var(const char *msg, + unsigned short var) { + const char *safe_msg = msg; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned short value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Template declaration for device-side debug printing (buffer only) +template +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var); + +// Specialization for signed char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, signed char var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=signed char value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for unsigned char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, unsigned char var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=unsigned char value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for integer type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + int var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} + +// Specialization for float type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + float var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=float value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} + +// Specialization for half_t type +template <> +__device__ void debug_print_buffer_value(const char *msg, + 
const char *buf_name, + int index, half_t var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + float value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=half_t value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for double type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, double var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=double value=%lf\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} diff --git a/tilelang/original/src/tl_templates/dcu_hip/gemm.h b/tilelang/original/src/tl_templates/dcu_hip/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..1d6a9d03b41e34e8dba0700f50e3d605e356fd8b --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/gemm.h @@ -0,0 +1,323 @@ +#pragma once + +#include "common.h" +#include + +namespace tl { + +// Trait to determine the MFMA instruction to use based on data type +template struct MfmaTraits; + +// Specialization for int8 +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) { + int64_t *b_packed = reinterpret_cast(const_cast(b)); + int64_t *a_packed = reinterpret_cast(const_cast(a)); + + *c = __builtin_amdgcn_mmac_i32_16x16x32i8(*b_packed, *a_packed, *c); + } +}; + +// Specialization for half/float16 +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) { + *c = __builtin_amdgcn_mmac_f32_16x16x16f16(*((float16x4 *)b), + *((float16x4 *)a), *c); + } +}; + +// Specialization for bfloat16_t +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a, + AccType *c) { + bfloat16x4_vec b_vec, a_vec; + + // Reinterpret the pointers + short *b_short = reinterpret_cast(const_cast(b)); + short *a_short = reinterpret_cast(const_cast(a)); + + // Copy the data + for (int i = 0; i < 4; ++i) { + b_vec[i] = b_short[i]; + a_vec[i] = a_short[i]; + } + + // Call the intrinsic and store the result directly to c + *c = __builtin_amdgcn_mmac_f32_16x16x16bf16(b_vec, a_vec, *c); + } +}; + +#if defined(HIP_FP8_ENABLED) +// Specialization for fp8_e4_t +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a, + AccType *c) { + int64_t a_val = *reinterpret_cast(a); + int64_t b_val = *reinterpret_cast(b); + *c = __builtin_amdgcn_mmac_f32_16x16x32_fp8_fp8(b_val, a_val, *c); + } +}; +#endif + +// ref to bitblas/tl/mfma_macro_generator.py::kPack +template +class GemmTensorOp { +public: + // static_assert(!clear_accum, "clear_accum=true is not supported yet"); + + static constexpr int micro_size_x = 16; + static constexpr int micro_size_y = 16; + static constexpr int micro_size_k = 32 / sizeof(A_type); + static constexpr int vec_size = 8 / sizeof(A_type); + + // This part comes from the Codegen + static constexpr int M_Tile = N; + static constexpr int N_Tile = M; + static constexpr int K_Tile = K; + + static constexpr int block_row_warps = num_warp_m; + static constexpr int block_col_warps = 
num_warp_n; + + static constexpr int inner_k = K_Tile / (micro_size_k * kPack); + static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x); + static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y); + + // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen + // part. + static constexpr bool kPadA = true; + static constexpr bool kPadB = true; + static constexpr bool kPadC = true; + + static constexpr int BANK_SIZE_BYTES = 128; + + static constexpr int warp_size = 64; + + TL_DEVICE static constexpr auto reverse_index_map(int thread_id, + int local_id) { + return std::make_pair(thread_id % 16, + (thread_id / 16) * (vec_size * kPack) + local_id); + } + + TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id, + int local_id) { + return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id, + thread_id % 16); + } + + /* + * Detailed Implementation please + * checkout bitblas/tl/utils.py:get_swizzle_layout + */ + template + TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) { + const auto dtype_bits = element_size * 8; + + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + const int vecSize = vec_size * kPack; + const int innerDimLength = continuous; + const int typeWidthInBit = dtype_bits; + + const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + const int maxPhase = + std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + const int phase = (row / perPhase) % maxPhase; + const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize); + const int colOffOrdered = col % vecSize; + const int colOff = colOffSwizzled + colOffOrdered; + + return std::make_pair(row, colOff); + } + + template + TL_DEVICE static constexpr auto make_layout_padded(const int row, + const int col) { + return std::make_pair(row, col); + } + + template + TL_DEVICE static constexpr auto make_swizzle_layout(const int row, + const int col) { + auto [n_row, n_col] = + make_mfma_swizzle_layout(row, col); + return n_row * continuous + n_col; + } + + static TL_DEVICE void body(A_type *A_shared, B_type *B_shared, + C_type *C_local) { + auto tid = threadIdx.x; + auto warp_id = tid / warp_size; + auto warp_m = warp_id / block_col_warps; + auto warp_n = warp_id % block_col_warps; + auto warp_row_tiles = warp_rows * micro_size_x; + auto warp_col_tiles = warp_cols * micro_size_y; + + auto lane_id = tid % warp_size; + auto tx = lane_id; + + auto alane_id = lane_id; + auto blane_id = + ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4); + + constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size; + constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size; + constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size; + + constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile; + constexpr auto last_dim_a = TransposeA ? 
N_Tile : K_Tile; + + B_type B_local[warp_rows * kPack * local_size_b]; + A_type A_local[warp_cols * kPack * local_size_a]; + + for (int ki = 0; ki < inner_k; ki++) { + // Fetch B into register + for (int i = 0; i < warp_rows; i++) { + const auto l = warp_m * warp_row_tiles + i * micro_size_x; + const auto r = ki * (kPack * micro_size_k); + for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { + if constexpr (TransposeB) { + auto [row, col] = reverse_index_map(blane_id, local_id); + B_local[i * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + l + row, r + col)]; + } else { + auto [row, col] = reverse_index_map_transposed(blane_id, local_id); + B_local[i * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + r + row, l + col)]; + } + } + } + // Fetch A into register + for (int j = 0; j < warp_cols; j++) { + const auto l = warp_n * warp_col_tiles + j * micro_size_y; + const auto r = ki * (kPack * micro_size_k); + for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) { + if constexpr (TransposeA) { + auto [row, col] = reverse_index_map_transposed(alane_id, local_id); + A_local[j * kPack * local_size_a + local_id] = + A_shared[make_swizzle_layout( + r + row, l + col)]; + } else { + auto [row, col] = reverse_index_map(alane_id, local_id); + A_local[j * kPack * local_size_a + local_id] = + A_shared[make_swizzle_layout( + l + row, r + col)]; + } + } + } + // Compute + for (int kp = 0; kp < kPack; kp++) { + for (int i = 0; i < warp_rows; ++i) { + for (int j = 0; j < warp_cols; ++j) { + auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j); + auto a_ptr = ((A_type *)A_local) + (j * kPack + kp) * vec_size; + auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size; + + // Use the trait to select the correct MFMA instruction, either fp8, + // fp16 or bf16 currently + MfmaTraits::mfma_op(a_ptr, b_ptr, acc_ptr); + } + } + } + } + } + + static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared, + C_type *C_local) { + auto tid = threadIdx.x; + auto warp_id = tid / warp_size; + auto warp_m = warp_id / block_col_warps; + auto warp_n = warp_id % block_col_warps; + auto warp_row_tiles = warp_rows * micro_size_x; + auto warp_col_tiles = warp_cols * micro_size_y; + + auto lane_id = tid % warp_size; + auto tx = lane_id; + + auto alane_id = lane_id; + auto blane_id = + ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4); + + constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size; + constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size; + constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size; + + constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile; + constexpr auto last_dim_a = TransposeA ? 
N_Tile : K_Tile; + + B_type B_local[warp_rows * kPack * local_size_b]; + + for (int ki = 0; ki < inner_k; ki++) { + // Fetch B into register + for (int i = 0; i < warp_rows; i++) { + const auto l = warp_m * warp_row_tiles + i * micro_size_x; + const auto r = ki * (kPack * micro_size_k); + for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { + if constexpr (TransposeB) { + auto [row, col] = reverse_index_map(blane_id, local_id); + B_local[i * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + l + row, r + col)]; + } else { + auto [row, col] = reverse_index_map_transposed(blane_id, local_id); + B_local[i * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + r + row, l + col)]; + } + } + } + + // Compute + for (int kp = 0; kp < kPack; kp++) { + for (int i = 0; i < warp_rows; ++i) { + for (int j = 0; j < warp_cols; ++j) { + auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j); + auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size; + auto a_ptr = ((A_type *)A_local) + + (ki * warp_cols * kPack + j * kPack + kp) * vec_size; + + // Use the trait to select the correct MFMA instruction, either fp8, + // fp16 or bf16 currently + MfmaTraits::mfma_op(a_ptr, b_ptr, acc_ptr); + } + } + } + } + } +}; + +} // namespace tl + +namespace tl { + +template +TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using Compute = + GemmTensorOp; + Compute::body(pA, pB, accum); +} + +template +TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using Compute = + GemmTensorOp; + Compute::body_rs(pA, pB, accum); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/dcu_hip/hip_fp8.h b/tilelang/original/src/tl_templates/dcu_hip/hip_fp8.h new file mode 100644 index 0000000000000000000000000000000000000000..96eb6844d6a31b0ead960803c103e3bd35045f28 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/hip_fp8.h @@ -0,0 +1,74 @@ +#include + +#define HIP_FP8_ENABLED 1 + +using fp8_e4_t = __hip_fp8_e4m3_fnuz; +using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz; + +// Simple wrapper that provides member access for generated code +struct fp8_e4_4_t { + union { + __hip_fp8x4_e4m3_fnuz data; + struct { + fp8_e4_t x, y, z, w; + }; + }; + + // Default constructor + __device__ fp8_e4_4_t() = default; + + // Constructor from __hip_fp8x4_e4m3_fnuz + __device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {} + + // Constructor from float4 + __device__ fp8_e4_4_t(const float4 &val) : data(val) {} + + // Conversion operator to __hip_fp8x4_e4m3_fnuz + __device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; } + + // Assignment operator + __device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) { + data = val; + return *this; + } +}; + +struct __align__(8) fp8_e4_8_t { + fp8_e4_4_t x; + fp8_e4_4_t y; +}; + +struct __align__(16) fp8_e4_16_t { + fp8_e4_8_t x; + fp8_e4_8_t y; +}; + +__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, + fp8_e4_t w) { + // reinterpret the 4 fp8_e4_t values to signed char value and shift + signed char x_char = *reinterpret_cast(&x); + signed char y_char = *reinterpret_cast(&y); + signed char z_char = *reinterpret_cast(&z); + signed char w_char = *reinterpret_cast(&w); + int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; + return *reinterpret_cast(&res); +} + +__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, + fp8_e4_t w, fp8_e4_t v, fp8_e4_t u, + fp8_e4_t t, fp8_e4_t s) { + signed char x_char = 
*reinterpret_cast(&x); + signed char y_char = *reinterpret_cast(&y); + signed char z_char = *reinterpret_cast(&z); + signed char w_char = *reinterpret_cast(&w); + signed char v_char = *reinterpret_cast(&v); + signed char u_char = *reinterpret_cast(&u); + signed char t_char = *reinterpret_cast(&t); + signed char s_char = *reinterpret_cast(&s); + int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; + int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char; + fp8_e4_8_t res; + res.x = *reinterpret_cast(&a); + res.y = *reinterpret_cast(&b); + return res; +} diff --git a/tilelang/original/src/tl_templates/dcu_hip/ldsm.h b/tilelang/original/src/tl_templates/dcu_hip/ldsm.h new file mode 100644 index 0000000000000000000000000000000000000000..286b77324262e04300b0240983d21036fa85125b --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/ldsm.h @@ -0,0 +1,3 @@ +#pragma once + +#include "common.h" diff --git a/tilelang/original/src/tl_templates/dcu_hip/reduce.h b/tilelang/original/src/tl_templates/dcu_hip/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..728579338b584419c5c8e3cfe5a74b0eeb15e5d1 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/reduce.h @@ -0,0 +1,169 @@ +#pragma once + +#include "common.h" + +namespace tl { + +struct SumOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x + y; + } +}; + +struct MaxOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return ck_tile::max(x, y); + } +}; + +struct MinOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return ck_tile::min(x, y); + } +}; +// Detect half types +template struct is_half_type : std::false_type {}; + +template <> struct is_half_type<__half> : std::true_type {}; + +template <> struct is_half_type<_Float16> : std::true_type {}; + +template +inline constexpr bool is_half_v = is_half_type>::value; + +template +struct AllReduce { + static_assert(threads == 1024 || threads == 512 || threads == 256 || + threads == 128 || threads == 64 || threads == 32 || + threads == 16 || threads == 8 || threads == 4 || threads == 2); + static_assert(threads % scale == 0); + + template static __device__ T run(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + constexpr int warpSize = 64; + + if constexpr (offset >= warpSize) { + __syncthreads(); + red_buf[threadIdx.x] = x; + __syncthreads(); + x = Reducer()(x, red_buf[threadIdx.x ^ offset]); + } else { + if constexpr (is_half_v) { + unsigned short x_raw; + if constexpr (std::is_same_v, __half>) { + x_raw = __half_as_ushort(x); + } else { // _Float16 + union { + _Float16 f; + unsigned short s; + } u; + u.f = x; + x_raw = u.s; + } + + unsigned short shuffled_raw = __shfl_xor(x_raw, offset); + + T shuffled_x; + if constexpr (std::is_same_v, __half>) { + shuffled_x = __ushort_as_half(shuffled_raw); + } else { // _Float16 + union { + unsigned short s; + _Float16 f; + } u; + u.s = shuffled_raw; + shuffled_x = u.f; + } + + x = Reducer()(x, shuffled_x); + } else { + x = Reducer()(x, __shfl_xor(x, offset)); + } + } + + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run(x, red_buf); + } + } +}; +template struct CumSum2D { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32); + template + static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H, + int W) { + + constexpr int TILE_H = threads / SEG; + constexpr uint64_t MASK = 0xffffffffffffffffULL; + const 
int num_blocks = (H + TILE_H - 1) / TILE_H; + const int tid = threadIdx.x; + const int lane = tid % 64; + const int row = tid / 64; + + for (int b = 0; b < num_blocks; ++b) { + const int gRow = b * TILE_H + row; + if (gRow >= H) + return; + + T carry = (T)0; + + if (reverse) { + // Start from the last segment for reverse mode + for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_down_sync(MASK, val, off); + if (lane < SEG - off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = (T)__shfl_sync(MASK, val, (T)0); + if (lane == 0) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, (T)0); + } + } else { + for (int seg = 0; seg * SEG < W; ++seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_up_sync(MASK, val, off); + if (lane >= off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = (T)__shfl_sync(MASK, val, SEG - 1); + if (lane == SEG - 1) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, SEG - 1); + } + } + } + } +}; +} // namespace tl diff --git a/tilelang/original/src/tl_templates/dcu_hip/threadblock_swizzle.h b/tilelang/original/src/tl_templates/dcu_hip/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..7771f0b98598157a51a0e992fa06c9c625fb0be7 --- /dev/null +++ b/tilelang/original/src/tl_templates/dcu_hip/threadblock_swizzle.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common.h" + +namespace tl { + +template TL_DEVICE dim3 rasterization2DRow() { + auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.x; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.x; + const unsigned int col_idx = (panel_idx & 1) + ? gridDim.x - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +template TL_DEVICE dim3 rasterization2DColumn() { + auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.y; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.y; + const unsigned int row_idx = (panel_idx & 1) + ? 
gridDim.y - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/hip/common.h b/tilelang/original/src/tl_templates/hip/common.h new file mode 100644 index 0000000000000000000000000000000000000000..8be247e77e85a53eb5affaedf1178c6e1a95bb4e --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/common.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include +#include +#include + +#define HIPRT_INF_F __int_as_float(0x7f800000) +#define HIPRT_NEGINF_F __int_as_float(0xff800000) +#define HIPRT_NAN_F __int_as_float(0x7fffffff) +#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001) +#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff) +#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000) +#define HIPRT_ZERO_F 0.0f +#define HIPRT_ONE_F 1.0f + +/* double precision constants */ +#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000) +#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000) + +#define uint unsigned int +#define uchar unsigned char +#define ushort unsigned short + +#define TL_DEVICE __forceinline__ __device__ +#define TL_DEVICE_NOINLINE __noinline__ __device__ + +#define TILELANG_CHECK(stmt) \ + do { \ + hipError_t __err = (stmt); \ + if (__err != hipSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \ + __LINE__, hipGetErrorName(__err), hipGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define TILELANG_CHECK_LAST_ERROR(kernel_name) \ + do { \ + hipError_t __err = hipGetLastError(); \ + if (__err != hipSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "kernel_name: %s - %s", \ + hipGetErrorName(__err), hipGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define half _Float16 +#define __float2half_rn(x) half(x) + +#define hpow __ocml_pown_f16 +#define hsqrt __ocml_sqrt_f16 + +using float16_t = _Float16; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; +using float16x8 = + __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t; +using float16x16 = + __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t; + +using half_t = float16_t; + +using bfloat16_t = hip_bfloat16; + +struct bfloat16x2 { + bfloat16_t x, y; +}; + +struct bfloat16x4 { + bfloat16_t data[4]; +}; + +struct bfloat16x8 { + bfloat16_t data[8]; +}; + +struct bfloat16x16 { + bfloat16_t data[16]; +}; + +typedef + __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec; + +using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + +using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; + +// Pack two half_t values. +TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. 
+TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +template +TL_DEVICE void AtomicAdd(T1 *address, T2 val) { + atomicAdd(reinterpret_cast(address), static_cast(val)); +} + +// Overload for when the first argument is a value instead of a pointer +template +TL_DEVICE void AtomicAdd(T1 address, T2 val) { + atomicAdd(reinterpret_cast(&address), static_cast(val)); +} + +template +TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val) { + return atomicAdd(reinterpret_cast(address), static_cast(val)); +} diff --git a/tilelang/original/src/tl_templates/hip/copy.h b/tilelang/original/src/tl_templates/hip/copy.h new file mode 100644 index 0000000000000000000000000000000000000000..3f122d801f80b77cd21826e839033d6acdf11f34 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/copy.h @@ -0,0 +1,112 @@ +#pragma once + +#include "common.h" + +using f32 = float; +// using f16 = _Float16; + +using u8 = std::uint8_t; +using u16 = std::uint16_t; +using u32 = std::uint32_t; + +using index_t = u32; + +using ck_tile::int32x4_t; + +struct __attribute__((packed)) buffer_resource { + const void *ptr; + uint32_t range; + uint32_t config; +}; + +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, + uint32_t size = 0xffffffff) { + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; + int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); + return r; +} + +__device__ void init_m0(uint32_t m0_value) { + asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory"); +} + +__device__ void inc_m0(uint32_t m0_inc) { + asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory"); +} + +namespace tl { + +// AMDGPU automatically commit memory fence +TL_DEVICE void cp_async_commit() {} + +// Global Memory only fence +__device__ void async_gld_fence(index_t cnt) { + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// Global Memory and Shared Memory fence +__device__ void async_gld_sld_fence(index_t cnt) { + asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); +} + +__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); } + +template TL_DEVICE void cp_async_wait() { + async_gld_fence(N); + // or + // async_gld_sld_fence(N); +} + +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc, + index_t voffset) { + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(smem))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(voffset), "s"(rsrc) + : "memory"); +} + +template +TL_DEVICE void cp_async_gs(void *lds_base_ptr, void const *global_base_ptr) { + if constexpr (N == 16) { + *(uint4 *)lds_base_ptr = *(const uint4 *)global_base_ptr; + } else if constexpr (N == 8) { + *(uint2 *)lds_base_ptr = *(const uint2 *)global_base_ptr; + } else if constexpr (N == 4) { + async_buffer_load_dword_v( + lds_base_ptr, + make_wave_buffer_resource(((const int32_t *)global_base_ptr) - + threadIdx.x), + threadIdx.x * N /*assume 4 bytes*/); + } +} + +template +TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, + void const *global_base_ptr, bool cond) { + if constexpr (N == 16) { + *(uint4 *)lds_base_ptr = + cond ? 
*(const uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0); + } else if constexpr (N == 8) { + *(uint2 *)lds_base_ptr = + cond ? *(const uint2 *)global_base_ptr : make_uint2(0, 0); + } else { + if (cond) { + async_buffer_load_dword_v( + lds_base_ptr, + make_wave_buffer_resource(((const int32_t *)global_base_ptr) - + threadIdx.x), + threadIdx.x * N /*assume 4 bytes*/); + } else { + *(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0); + } + } +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/hip/debug.h b/tilelang/original/src/tl_templates/hip/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..7b19d3e943719883a0defd1d5d422e40b9b6998c --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/debug.h @@ -0,0 +1,191 @@ +#pragma once +#include + +// Base template declaration +template __device__ void debug_print_var(const char *msg, T var); + +// Specialization for signed char type +template <> +__device__ void debug_print_var(const char *msg, signed char var) { + const char *safe_msg = msg; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed " + "char value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for unsigned char type +template <> +__device__ void debug_print_var(const char *msg, + unsigned char var) { + const char *safe_msg = msg; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned char value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for int type +template <> __device__ void debug_print_var(const char *msg, int var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " + "value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for unsigned int type +template <> +__device__ void debug_print_var(const char *msg, + unsigned int var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned int value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for float type +template <> __device__ void debug_print_var(const char *msg, float var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " + "value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for double type +template <> +__device__ void debug_print_var(const char *msg, double var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " + "value=%lf\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var); +} + +// Specialization for bool type +template <> __device__ void debug_print_var(const char *msg, bool var) { + const char *safe_msg = msg; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " + "value=%s\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, 
(int)threadIdx.y, (int)threadIdx.z, + var ? "true" : "false"); +} + +// Specialization for short type +template <> __device__ void debug_print_var(const char *msg, short var) { + const char *safe_msg = msg; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=short " + "value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Specialization for unsigned short type +template <> +__device__ void debug_print_var(const char *msg, + unsigned short var) { + const char *safe_msg = msg; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned short value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value); +} + +// Template declaration for device-side debug printing (buffer only) +template +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var); + +// Specialization for signed char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, signed char var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=signed char value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for unsigned char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, unsigned char var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + unsigned int value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=unsigned char value=%u\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for integer type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + int var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int value=%d\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} + +// Specialization for float type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + float var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=float value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} + +// Specialization for half_t type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, half_t var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + float value = static_cast(var); + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, 
dtype=half_t value=%f\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, value); +} + +// Specialization for double type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, double var) { + const char *safe_msg = msg; + const char *safe_buf_name = buf_name; + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=double value=%lf\n", + safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z, + (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name, + index, var); +} diff --git a/tilelang/original/src/tl_templates/hip/gemm.h b/tilelang/original/src/tl_templates/hip/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..068d57a64c3d10f3e0f4dd21fcfebfbd090ffbb3 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/gemm.h @@ -0,0 +1,318 @@ +#pragma once + +#include "common.h" +#include + +namespace tl { + +// Trait to determine the MFMA instruction to use based on data type +template struct MfmaTraits; + +// Specialization for int8 +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) { + int64_t *b_packed = reinterpret_cast(const_cast(b)); + int64_t *a_packed = reinterpret_cast(const_cast(a)); + + *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0, + 0); + } +}; + +// Specialization for half/float16 +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) { + *c = __builtin_amdgcn_mfma_f32_16x16x16f16(*((float16x4 *)b), + *((float16x4 *)a), *c, 0, 0, 0); + } +}; + +// Specialization for bfloat16_t +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a, + AccType *c) { + bfloat16x4_vec b_vec, a_vec; + + // Reinterpret the pointers + short *b_short = reinterpret_cast(const_cast(b)); + short *a_short = reinterpret_cast(const_cast(a)); + + // Copy the data + for (int i = 0; i < 4; ++i) { + b_vec[i] = b_short[i]; + a_vec[i] = a_short[i]; + } + + // Call the intrinsic and store the result directly to c + *c = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(b_vec, a_vec, *c, 0, 0, 0); + } +}; + +#if defined(HIP_FP8_ENABLED) +// Specialization for fp8_e4_t +template <> struct MfmaTraits { + template + static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a, + AccType *c) { + int64_t a_val = *reinterpret_cast(a); + int64_t b_val = *reinterpret_cast(b); + *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0); + } +}; +#endif + +// ref to bitblas/tl/mfma_macro_generator.py::kPack +template +class GemmTensorOp { +public: + // Note: clear_accum=true is not fully supported in HIP implementation + // but we'll handle it by manually clearing the accumulator + // static_assert(!clear_accum, "clear_accum=true is not supported yet"); + + static constexpr int micro_size_x = 16; + static constexpr int micro_size_y = 16; + static constexpr int micro_size_k = 32 / sizeof(A_type); + static constexpr int vec_size = 8 / sizeof(A_type); + + // This part comes from the Codegen + static constexpr int M_Tile = M; + static constexpr int N_Tile = N; + static constexpr int K_Tile = K; + + static constexpr int block_row_warps = num_warp_m; + static constexpr int block_col_warps = num_warp_n; + + static constexpr int inner_k = K_Tile / (micro_size_k * kPack); 
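+  // warp_rows / warp_cols: how many 16x16 output micro-tiles each warp covers along M and N.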
+ static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x); + static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y); + + // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen + // part. + static constexpr bool kPadA = true; + static constexpr bool kPadB = true; + static constexpr bool kPadC = true; + + static constexpr int BANK_SIZE_BYTES = 128; + + static constexpr int warp_size = 64; + + TL_DEVICE static constexpr auto reverse_index_map(int thread_id, + int local_id) { + return std::make_pair(thread_id % 16, + (thread_id / 16) * (vec_size * kPack) + local_id); + } + + TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id, + int local_id) { + return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id, + thread_id % 16); + } + + /* + * Detailed Implementation please + * checkout bitblas/tl/utils.py:get_swizzle_layout + */ + template + TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) { + const auto dtype_bits = element_size * 8; + + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + const int vecSize = vec_size * kPack; + const int innerDimLength = continuous; + const int typeWidthInBit = dtype_bits; + + const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + const int maxPhase = + std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + const int phase = (row / perPhase) % maxPhase; + const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize); + const int colOffOrdered = col % vecSize; + const int colOff = colOffSwizzled + colOffOrdered; + + return std::make_pair(row, colOff); + } + + template + TL_DEVICE static constexpr auto make_layout_padded(const int row, + const int col) { + return std::make_pair(row, col); + } + + template + TL_DEVICE static constexpr auto make_swizzle_layout(const int row, + const int col) { + auto [n_row, n_col] = + make_mfma_swizzle_layout(row, col); + return n_row * continuous + n_col; + } + + static TL_DEVICE void body(A_type *A_shared, B_type *B_shared, + C_type *C_local) { + auto tid = threadIdx.x; + auto warp_id = tid / warp_size; + auto warp_n = warp_id / block_row_warps; + auto warp_m = warp_id % block_row_warps; + auto warp_row_tiles = warp_rows * micro_size_x; + auto warp_col_tiles = warp_cols * micro_size_y; + + auto lane_id = tid % warp_size; + auto tx = lane_id; + + constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size; + constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size; + constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size; + + constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile; + constexpr auto last_dim_b = TransposeB ? 
K_Tile : N_Tile; + + A_type A_local[warp_rows * kPack * local_size_a]; + B_type B_local[warp_cols * kPack * local_size_b]; + + for (int ki = 0; ki < inner_k; ki++) { + // Fetch A into register + for (int i = 0; i < warp_rows; i++) { + const auto l = warp_m * warp_row_tiles + i * micro_size_x; + const auto r = ki * (kPack * micro_size_k); + for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) { + if constexpr (TransposeA) { + auto [row, col] = reverse_index_map_transposed(lane_id, local_id); + A_local[i * kPack * local_size_a + local_id] = + A_shared[make_swizzle_layout( + r + row, l + col)]; + } else { + auto [row, col] = reverse_index_map(lane_id, local_id); + A_local[i * kPack * local_size_a + local_id] = + A_shared[make_swizzle_layout( + l + row, r + col)]; + } + } + } + // Fetch B into register + for (int j = 0; j < warp_cols; j++) { + const auto l = warp_n * warp_col_tiles + j * micro_size_y; + const auto r = ki * (kPack * micro_size_k); + for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { + if constexpr (TransposeB) { + auto [row, col] = reverse_index_map(lane_id, local_id); + B_local[j * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + l + row, r + col)]; + } else { + auto [row, col] = reverse_index_map_transposed(lane_id, local_id); + B_local[j * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + r + row, l + col)]; + } + } + } + // Compute + for (int kp = 0; kp < kPack; kp++) { + for (int i = 0; i < warp_rows; ++i) { + for (int j = 0; j < warp_cols; ++j) { + auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j); + auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size; + auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * vec_size; + + // Use the trait to select the correct MFMA instruction, either fp8, + // fp16 or bf16 currently + MfmaTraits::mfma_op(b_ptr, a_ptr, acc_ptr); + } + } + } + } + } + + static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared, + C_type *C_local) { + auto tid = threadIdx.x; + auto warp_id = tid / warp_size; + auto warp_n = warp_id / block_row_warps; + auto warp_m = warp_id % block_row_warps; + auto warp_row_tiles = warp_rows * micro_size_x; + auto warp_col_tiles = warp_cols * micro_size_y; + + auto lane_id = tid % warp_size; + auto tx = lane_id; + + constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size; + constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size; + constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size; + + constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile; + constexpr auto last_dim_b = TransposeB ? 
K_Tile : N_Tile; + + B_type B_local[warp_cols * kPack * local_size_b]; + + for (int ki = 0; ki < inner_k; ki++) { + // Fetch B into register + for (int j = 0; j < warp_cols; j++) { + const auto l = warp_n * warp_col_tiles + j * micro_size_y; + const auto r = ki * kPack * micro_size_k; + for (int local_id = 0; local_id < kPack * local_size_b; local_id++) { + if constexpr (TransposeB) { + auto [row, col] = reverse_index_map(lane_id, local_id); + B_local[j * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + l + row, r + col)]; + } else { + auto [row, col] = reverse_index_map_transposed(lane_id, local_id); + B_local[j * kPack * local_size_b + local_id] = + B_shared[make_swizzle_layout( + r + row, l + col)]; + } + } + } + + // Compute + for (int kp = 0; kp < kPack; kp++) { + for (int i = 0; i < warp_rows; ++i) { + for (int j = 0; j < warp_cols; ++j) { + auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j); + auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size; + auto a_ptr = ((A_type *)A_local) + + (ki * warp_rows * kPack + i * kPack + kp) * vec_size; + + // Use the trait to select the correct MFMA instruction, either fp8, + // fp16 or bf16 currently + MfmaTraits::mfma_op(b_ptr, a_ptr, acc_ptr); + } + } + } + } + } +}; + +} // namespace tl + +namespace tl { + +template +TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using Compute = + GemmTensorOp; + Compute::body(pA, pB, accum); +} + +template +TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using Compute = + GemmTensorOp; + Compute::body_rs(pA, pB, accum); +} + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/hip/hip_fp8.h b/tilelang/original/src/tl_templates/hip/hip_fp8.h new file mode 100644 index 0000000000000000000000000000000000000000..b32f84dca22a9358a882e7ecd3a1367b9e144ca2 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/hip_fp8.h @@ -0,0 +1,167 @@ +#include + +#define HIP_FP8_ENABLED 1 + +using fp8_e4_t = __hip_fp8_e4m3_fnuz; +using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz; + +// Additional FP8 types for compatibility +using fp8_e5_t = __hip_fp8_e5m2_fnuz; +using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz; +// Note: E8M0 types are not supported in current HIP version +// using fp8_e8_t = __hip_fp8_e8m0_fnuz; +// using fp8_e8_2_t = __hip_fp8x2_e8m0_fnuz; + +// Simple wrapper that provides member access for generated code +struct fp8_e4_4_t { + union { + __hip_fp8x4_e4m3_fnuz data; + struct { + fp8_e4_t x, y, z, w; + }; + }; + + // Default constructor + __device__ fp8_e4_4_t() = default; + + // Constructor from __hip_fp8x4_e4m3_fnuz + __device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {} + + // Constructor from float4 + __device__ fp8_e4_4_t(const float4 &val) : data(val) {} + + // Conversion operator to __hip_fp8x4_e4m3_fnuz + __device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; } + + // Assignment operator + __device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) { + data = val; + return *this; + } +}; + +struct __align__(8) fp8_e4_8_t { + fp8_e4_4_t x; + fp8_e4_4_t y; +}; + +struct __align__(16) fp8_e4_16_t { + fp8_e4_8_t x; + fp8_e4_8_t y; +}; + +// FP8 E5M2 vector types +struct fp8_e5_4_t { + union { + __hip_fp8x4_e5m2_fnuz data; + struct { + fp8_e5_t x, y, z, w; + }; + }; + __device__ fp8_e5_4_t() = default; + __device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {} + __device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; } +}; + +struct __align__(8) fp8_e5_8_t { + fp8_e5_4_t x; 
+ fp8_e5_4_t y; +}; + +struct __align__(16) fp8_e5_16_t { + fp8_e5_8_t x; + fp8_e5_8_t y; +}; + +// FP8 E8M0 vector types - not supported in current HIP version +/* +struct fp8_e8_4_t { + union { + __hip_fp8x4_e8m0_fnuz data; + struct { + fp8_e8_t x, y, z, w; + }; + }; + __device__ fp8_e8_4_t() = default; + __device__ fp8_e8_4_t(const __hip_fp8x4_e8m0_fnuz &val) : data(val) {} + __device__ operator __hip_fp8x4_e8m0_fnuz() const { return data; } +}; + +struct __align__(8) fp8_e8_8_t { + fp8_e8_4_t x; + fp8_e8_4_t y; +}; + +struct __align__(16) fp8_e8_16_t { + fp8_e8_8_t x; + fp8_e8_8_t y; +}; +*/ + +__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, + fp8_e4_t w) { + // reinterpret the 4 fp8_e4_t values to signed char value and shift + signed char x_char = *reinterpret_cast(&x); + signed char y_char = *reinterpret_cast(&y); + signed char z_char = *reinterpret_cast(&z); + signed char w_char = *reinterpret_cast(&w); + int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; + return *reinterpret_cast(&res); +} + +__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, + fp8_e4_t w, fp8_e4_t v, fp8_e4_t u, + fp8_e4_t t, fp8_e4_t s) { + signed char x_char = *reinterpret_cast(&x); + signed char y_char = *reinterpret_cast(&y); + signed char z_char = *reinterpret_cast(&z); + signed char w_char = *reinterpret_cast(&w); + signed char v_char = *reinterpret_cast(&v); + signed char u_char = *reinterpret_cast(&u); + signed char t_char = *reinterpret_cast(&t); + signed char s_char = *reinterpret_cast(&s); + int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; + int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char; + fp8_e4_8_t res; + res.x = *reinterpret_cast(&a); + res.y = *reinterpret_cast(&b); + return res; +} + +__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, + fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5, + fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0, + fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, + fp8_e4_t y7) { + signed char x0_char = *reinterpret_cast(&x0); + signed char x1_char = *reinterpret_cast(&x1); + signed char x2_char = *reinterpret_cast(&x2); + signed char x3_char = *reinterpret_cast(&x3); + signed char x4_char = *reinterpret_cast(&x4); + signed char x5_char = *reinterpret_cast(&x5); + signed char x6_char = *reinterpret_cast(&x6); + signed char x7_char = *reinterpret_cast(&x7); + signed char y0_char = *reinterpret_cast(&y0); + signed char y1_char = *reinterpret_cast(&y1); + signed char y2_char = *reinterpret_cast(&y2); + signed char y3_char = *reinterpret_cast(&y3); + signed char y4_char = *reinterpret_cast(&y4); + signed char y5_char = *reinterpret_cast(&y5); + signed char y6_char = *reinterpret_cast(&y6); + signed char y7_char = *reinterpret_cast(&y7); + int a = (x3_char << 24) | (x2_char << 16) | (x1_char << 8) | x0_char; + int b = (x7_char << 24) | (x6_char << 16) | (x5_char << 8) | x4_char; + int c = (y3_char << 24) | (y2_char << 16) | (y1_char << 8) | y0_char; + int d = (y7_char << 24) | (y6_char << 16) | (y5_char << 8) | y4_char; + fp8_e4_8_t res_x; + res_x.x = *reinterpret_cast(&a); + res_x.y = *reinterpret_cast(&b); + fp8_e4_8_t res_y; + res_y.x = *reinterpret_cast(&c); + res_y.y = *reinterpret_cast(&d); + fp8_e4_16_t res; + res.x = res_x; + res.y = res_y; + return res; +} \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/hip/ldsm.h b/tilelang/original/src/tl_templates/hip/ldsm.h new file mode 100644 index 
0000000000000000000000000000000000000000..68c1455f7fb552b63380b0f68fe5fcc8d59df7a9 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/ldsm.h @@ -0,0 +1,3 @@ +#pragma once + +#include "common.h" \ No newline at end of file diff --git a/tilelang/original/src/tl_templates/hip/reduce.h b/tilelang/original/src/tl_templates/hip/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..16c51b648654f49bff41de0b1f55dc9691d60b94 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/reduce.h @@ -0,0 +1,117 @@ +#pragma once + +#include "common.h" + +namespace tl { + +struct SumOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x + y; + } +}; + +struct MaxOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return ck_tile::max(x, y); + } +}; + +struct MinOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return ck_tile::min(x, y); + } +}; + +struct BitAndOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x & y; + } +}; + +struct BitOrOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x | y; + } +}; + +struct BitXorOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x ^ y; + } +}; + +template +struct SharedReduceWarp { + template + static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, + int total_dest, int reduce_extent, int tail, + T init_value) { + if (total_dest <= 0 || reduce_extent <= 0) + return; + constexpr int kWarpSize = 64; + static_assert(Threads % kWarpSize == 0, + "SharedReduceWarp expects blockDim.x to be a multiple of " + "wave size on HIP."); + const int tid = threadIdx.x; + const int warp_id = tid / kWarpSize; + const int lane = tid % kWarpSize; + const int num_warps = Threads / kWarpSize; + + for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) { + const int prefix = tail == 1 ? dest_idx : dest_idx / tail; + const int suffix = tail == 1 ? 0 : dest_idx % tail; + const int src_base = (prefix * reduce_extent) * tail + suffix; + const int dst_index = prefix * tail + suffix; + + T partial = init_value; + for (int rv = lane; rv < reduce_extent; rv += kWarpSize) { + T val = src[src_base + rv * tail]; + if constexpr (UseAbs) { + val = val < T(0) ? 
-val : val; + } + partial = Reducer()(partial, val); + } + + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + T other = __shfl_down(partial, offset, kWarpSize); + partial = Reducer()(partial, other); + } + + if (lane == 0) { + if constexpr (NeedAccumulate) { + partial = Reducer()(dst[dst_index], partial); + } + dst[dst_index] = partial; + } + } + } +}; + +template +struct AllReduce { + static_assert(threads == 1024 || threads == 512 || threads == 256 || + threads == 128 || threads == 64 || threads == 32 || + threads == 16 || threads == 8 || threads == 4 || threads == 2); + static_assert(threads % scale == 0); + + template static __device__ T run(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + constexpr int warpSize = 64; + + if constexpr (offset >= warpSize) { + __syncthreads(); + red_buf[threadIdx.x] = x; + __syncthreads(); + x = Reducer()(x, red_buf[threadIdx.x ^ offset]); + } else { + x = Reducer()(x, __shfl_xor(x, offset)); + } + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run(x, red_buf); + } + } +}; + +} // namespace tl diff --git a/tilelang/original/src/tl_templates/hip/threadblock_swizzle.h b/tilelang/original/src/tl_templates/hip/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..7771f0b98598157a51a0e992fa06c9c625fb0be7 --- /dev/null +++ b/tilelang/original/src/tl_templates/hip/threadblock_swizzle.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common.h" + +namespace tl { + +template TL_DEVICE dim3 rasterization2DRow() { + auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.x; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.x; + const unsigned int col_idx = (panel_idx & 1) + ? gridDim.x - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +template TL_DEVICE dim3 rasterization2DColumn() { + auto ceil_div = [](int a, int b) { return (a + b - 1) / b; }; + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.y; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.y; + const unsigned int row_idx = (panel_idx & 1) + ? 
gridDim.y - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +} // namespace tl diff --git a/tilelang/original/src/transform/align_dynamic_shared_memory_allocations.cc b/tilelang/original/src/transform/align_dynamic_shared_memory_allocations.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c2519df99ea7c76fbb5e3226a96ff27407e801f --- /dev/null +++ b/tilelang/original/src/transform/align_dynamic_shared_memory_allocations.cc @@ -0,0 +1,159 @@ +/*! + * \file align_dynamic_shared_memory_allocations.cc + * \brief align dynamic shared memory allocations + */ + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { +public: + explicit TileLangAlignDynamicSharedMemoryAllocations(int align_bytes) + : align_bytes_(align_bytes) {} + + static Stmt Substitute(int align_bytes, const Stmt &stmt) { + TileLangAlignDynamicSharedMemoryAllocations smem_rewriter(align_bytes); + return smem_rewriter.VisitStmt(stmt); + } + + Stmt VisitStmt_(const AllocateNode *op) final { + auto storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn") { + auto new_extents = + MakeRoundRobinAlignment(op->extents, align_bytes_, op->dtype.bytes()); + if (!new_extents.same_as(op->extents)) { + auto new_allocate = Allocate(op->buffer_var, op->dtype, new_extents, + op->condition, op->body, op->annotations); + return StmtExprMutator::VisitStmt(new_allocate); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const BlockNode *op) final { + Block block = tvm::ffi::GetRef(op); + Array alloc_buffers = op->alloc_buffers; + alloc_buffers.MutateByApply([this](Buffer buf) { + auto storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(buf->data)); + if (storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn") { + auto new_shape = MakeRoundRobinAlignment(buf->shape, align_bytes_, + buf->dtype.bytes()); + if (!new_shape.same_as(buf->shape)) { + ObjectPtr new_buffer = + tvm::ffi::make_object(*(buf.get())); + new_buffer->shape = std::move(new_shape); + buffer_remap_.Set(buf, Buffer(new_buffer)); + return Buffer(new_buffer); + } + } + return buf; + }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } + return StmtExprMutator::VisitStmt_(block.get()); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store_node = tvm::ffi::GetRef(op); + Buffer buf = op->buffer; + if (buffer_remap_.count(buf)) { + buf = buffer_remap_[buf]; + return BufferStore(buf, op->value, op->indices); + } + return StmtExprMutator::VisitStmt_(store_node.get()); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load_node = tvm::ffi::GetRef(op); + Buffer buf = op->buffer; + if (buffer_remap_.count(buf)) { + buf = buffer_remap_[buf]; + return BufferLoad(buf, op->indices); + } + return StmtExprMutator::VisitExpr_(load_node.get()); + } + +private: + static Array MakeRoundRobinAlignment(Array extents, + int align_bytes, + int dtype_bytes) { + if (extents.empty()) + 
return extents; + // Calculate total number of elements + PrimExpr total_elems = make_const(extents[0].dtype(), 1); + for (auto extent : extents) { + total_elems = total_elems * extent; + } + // Calculate total bytes + PrimExpr total_bytes = total_elems * dtype_bytes; + // Check if already aligned + PrimExpr remainder = indexmod(total_bytes, align_bytes); + if (is_zero(remainder)) { + return extents; + } + // Need to pad the last dimension + Array adjusted; + for (size_t i = 0; i < extents.size(); ++i) { + adjusted.push_back(extents[i]); + } + // Calculate padded last dimension + // pad = ceil(total_bytes / align_bytes) * align_bytes + PrimExpr last_extent = extents.back(); + PrimExpr other_elems = make_const(extents[0].dtype(), 1); + for (size_t i = 0; i < extents.size() - 1; ++i) { + other_elems = other_elems * extents[i]; + } + // new_last_extent = ceil(total_bytes / align_bytes) * align_bytes / + // (other_elems * dtype_bytes) + PrimExpr padded_total_bytes = + floordiv(total_bytes + align_bytes - 1, align_bytes) * align_bytes; + PrimExpr new_last_extent = + floordiv(padded_total_bytes, other_elems * dtype_bytes); + adjusted.Set(adjusted.size() - 1, new_last_extent); + return adjusted; + } + + int align_bytes_; + Map buffer_remap_; +}; + +tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { + using namespace tir::transform; + auto pass_func = [align_bytes](PrimFunc f, const IRModule &m, + const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + n->body = TileLangAlignDynamicSharedMemoryAllocations::Substitute( + align_bytes, n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, + "tl.AlignDynamicSharedMemoryAllocations", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations", + AlignDynamicSharedMemoryAllocations); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/annotate_device_regions.cc b/tilelang/original/src/transform/annotate_device_regions.cc new file mode 100644 index 0000000000000000000000000000000000000000..ecc0cba9d2bdcc278b390cd5147478005578bcd9 --- /dev/null +++ b/tilelang/original/src/transform/annotate_device_regions.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file annotate_device_regions.cc + * \brief Split device function from host. 
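+ * Regions under thread_extent / pipeline_exec_scope / device_scope attributes are wrapped
+ * with the function's device target so host and device code can later be split.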
+ */ +#include "tir/transforms/ir_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class DeviceRegionAnnotater : public StmtMutator { +public: + explicit DeviceRegionAnnotater(Target device_target) + : device_target_(std::move(device_target)) {} + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::attr::kTarget) { + // If a target attribute already exists, use it as-is. + return tvm::ffi::GetRef(op); + } else if (op->attr_key == tir::attr::thread_extent || + op->attr_key == tir::attr::pipeline_exec_scope || + op->attr_key == tir::attr::device_scope) { + // These attributes are only allowed in device-side code, so + // they should be annotated with the function's default target. + Stmt body = tvm::ffi::GetRef(op); + return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); + } else { + // All other annotations are ignored + return StmtMutator::VisitStmt_(op); + } + } + +private: + Target device_target_; +}; + +tvm::transform::Pass AnnotateDeviceRegions() { + using namespace tir::transform; + auto pass_func = [](PrimFunc func, const IRModule &mod, + const tvm::transform::PassContext &ctx) -> PrimFunc { + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; + Target target = opt_target.value(); + Target device_target = target.WithoutHost(); + + if (target->GetHost()) { + if (device_target->kind->name == "c") { + // Annotate the function with the device target + auto func_body = func->body; + func.CopyOnWrite()->body = + AttrStmt(device_target, tvm::attr::kTarget, 0, func_body); + } + + DeviceRegionAnnotater mutator(target.WithoutHost()); + func.CopyOnWrite()->body = mutator(func->body); + } + return func; + }; + + return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions", + AnnotateDeviceRegions); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/annotate_read_only_params.cc b/tilelang/original/src/transform/annotate_read_only_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9eef683b5e36581f40080760d59b8fea34e3b4b --- /dev/null +++ b/tilelang/original/src/transform/annotate_read_only_params.cc @@ -0,0 +1,191 @@ +/*! + * \file annotate_read_only_params.cc + * \brief Annotate PrimFunc parameters that are read-only (never written). + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { +using namespace tir; +using namespace ffi; + +/*! + * \brief A simple visitor that marks handle parameters as written when they + * appear on the LHS of a BufferStore or in a tvm_access_ptr with write + * flag. + */ +class ReadWriteMarker : public StmtExprVisitor { +public: + explicit ReadWriteMarker( + const std::unordered_set ¶m_or_data_vars) + : param_or_data_vars_(param_or_data_vars) {} + + const std::unordered_set &written() const { + return written_; + } + + // Try to resolve the underlying buffer data Var from a pointer-like + // argument. Supports: + // - address_of(BufferLoad(...)) -> returns buffer->data + // - BufferLoad(...) -> returns buffer->data + // Otherwise returns nullptr. 
+ const VarNode *ResolveDataVarFromPtrArg(const PrimExpr &arg) const { + if (const auto *call = arg.as()) { + if (call->op.same_as(builtin::address_of())) { + if (call->args.size() == 1U) { + if (const auto *load = call->args[0].as()) { + return load->buffer->data.get(); + } + } + } + } else if (const auto *load = arg.as()) { + return load->buffer->data.get(); + } + return nullptr; + } + + void VisitStmt_(const BufferStoreNode *op) final { + const VarNode *data = op->buffer->data.get(); + if (param_or_data_vars_.count(data)) { + written_.insert(data); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + // Detect tvm_access_ptr writes. Be conservative if rw_mask is non-constant. + if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->args.size() == 5U) { + if (const VarNode *buf = op->args[1].as()) { + const IntImmNode *flag = op->args[4].as(); + bool maybe_write = true; // default conservative + if (flag) { + maybe_write = (flag->value & 2) != 0; // write bit set + } + if (maybe_write && param_or_data_vars_.count(buf)) { + written_.insert(buf); + } + } + } + } else { + // Generic fallback: mark buffers that appear as + // address_of(BufferLoad(...)) in call arguments as written. This matches + // patterns like + // tl.tma_store(address_of(smem[..]), address_of(gmem[..]), ...) + // call_extern("AtomicAdd*", address_of(gmem[..]), ...) + // and avoids over-marking plain BufferLoad used for reads. + for (const PrimExpr &a : op->args) { + if (const auto *c = a.as()) { + if (c->op.same_as(builtin::address_of()) && c->args.size() == 1U) { + if (const auto *bl = c->args[0].as()) { + const VarNode *data = bl->buffer->data.get(); + if (param_or_data_vars_.count(data)) { + written_.insert(data); + } + } + } + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + +private: + std::unordered_set param_or_data_vars_; + std::unordered_set written_; +}; + +/*! + * \brief Annotate PrimFunc with indices of read-only handle parameters. + * + * Adds an Array attribute "tl.readonly_param_indices" that lists + * parameter indices which correspond to handle parameters that are never + * written inside the function body. This can be used by codegen to emit + * `const` qualifiers to enable read-only caching (e.g., __ldg on CUDA). + */ +static tir::PrimFunc MarkReadOnlyParams(tir::PrimFunc f) { + // Gather handle params and their corresponding buffer data vars (aliases). + std::unordered_set param_or_data_vars; + // Map back from data var to parameter index for result attribution. + std::unordered_map data_var_to_param_idx; + + for (size_t i = 0; i < f->params.size(); ++i) { + const Var &p = f->params[i]; + if (!p->dtype.is_handle()) + continue; + param_or_data_vars.insert(p.get()); + // If there is a buffer_map entry for this param, include its data var too. + if (auto opt = f->buffer_map.Get(p)) { + const VarNode *data = opt.value()->data.get(); + param_or_data_vars.insert(data); + data_var_to_param_idx[data] = i; + } + } + if (param_or_data_vars.empty()) + return f; + + ReadWriteMarker marker(param_or_data_vars); + marker(f->body); + + // Determine read-only parameter indices among all params (handle only) + Array readonly_indices; + for (size_t i = 0; i < f->params.size(); ++i) { + const Var &v = f->params[i]; + if (!v->dtype.is_handle()) + continue; + + bool is_written = false; + // Direct param var written? + if (marker.written().count(v.get())) { + is_written = true; + } else { + // Or any aliased data var written? 
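+      // (A handle parameter also counts as written when the data var of its mapped buffer is written.)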
+ if (auto opt = f->buffer_map.Get(v)) { + if (marker.written().count(opt.value()->data.get())) { + is_written = true; + } + } + } + + if (!is_written) { + readonly_indices.push_back(Integer(static_cast(i))); + } + } + + if (!readonly_indices.empty()) { + Map attrs; + attrs.Set(String("tl.readonly_param_indices"), readonly_indices); + f = WithAttrs(std::move(f), attrs); + } + return f; +} + +namespace transform { +using namespace tir::transform; + +Pass AnnotateReadOnlyParams() { + auto pass_func = [](PrimFunc f, const IRModule &m, + const tvm::transform::PassContext &ctx) { + return MarkReadOnlyParams(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateReadOnlyParams", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.AnnotateReadOnlyParams", + AnnotateReadOnlyParams); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/annotate_warp_group_reg_alloc.cc b/tilelang/original/src/transform/annotate_warp_group_reg_alloc.cc new file mode 100644 index 0000000000000000000000000000000000000000..08be53f205ebdbc8f577b57ca4c30e8eed53a81b --- /dev/null +++ b/tilelang/original/src/transform/annotate_warp_group_reg_alloc.cc @@ -0,0 +1,198 @@ +/*! + * \file annotate_warp_group_reg_alloc.cc + * \brief Annotate warp group reg alloc for warp specialization + */ + +#include "warp_specialized_rewriter.h" +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class SetMaxNRegCollector : public StmtExprVisitor { +public: + static Array Collect(const PrimFunc &f) { + SetMaxNRegCollector collector; + collector(f->body); + if (collector.warp_specialized_) { + return Array({}); + } + return collector.has_no_set_max_nreg_ + ? 
Array({IntImm(DataType::Int(32), -1), + IntImm(DataType::Int(32), -1)}) + : collector.nreg_; + } + +private: + void VisitStmt_(const EvaluateNode *op) final { + if (const CallNode *call = op->value.as()) { + if (call->op.same_as(set_max_nreg())) { + auto reg_hint = call->args[0].as()->value; + auto is_inc = call->args[1].as()->value; + ICHECK(reg_hint <= 240 && reg_hint >= 24) + << "Invalid reg hint: " << reg_hint; + ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc; + + // producer should decrease register hint while consumer should increase + // register hint + nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint)); + } else if (call->op.same_as(no_set_max_nreg())) { + has_no_set_max_nreg_ = true; + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == attr::kCustomWarpSpecialization) { + warp_specialized_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + Array nreg_{IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), 0)}; + bool has_no_set_max_nreg_ = false; + bool warp_specialized_ = false; +}; + +class SimtCopyDetector : public StmtExprVisitor { +public: + static bool Detect(const Stmt &stmt) { + SimtCopyDetector detector; + detector.VisitStmt(stmt); + return detector.has_simt_copy_; + } + +private: + void VisitStmt_(const BufferStoreNode *op) final { + auto scope = + runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (scope.to_string() != "global") { + has_simt_copy_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + bool has_simt_copy_{false}; +}; + +class SetMaxNRegInjector : public StmtExprMutator { +public: + static PrimFunc Inject(PrimFunc f) { + auto T = SetMaxNRegInjector(); + T.nreg_ = SetMaxNRegCollector::Collect(f); + if (T.nreg_.empty()) { + return f; + } + f.CopyOnWrite()->body = T(f->body); + return f; + } + +private: + Stmt VisitStmt_(const EvaluateNode *op) final { + if (const CallNode *call = op->value.as()) { + if (call->op.same_as(no_set_max_nreg())) { + // Remove the original set_max_nreg calls as they will be re-inserted + // at appropriate locations + return Evaluate(0); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_iv_ = Downcast(op->node); + need_update_thread_extent_ = false; + AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (need_update_thread_extent_) { + thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; + attr_stmt.CopyOnWrite()->node = thread_iv_; + attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); + } + thread_iv_ = {}; + return attr_stmt; + } else if (op->attr_key == attr::kWarpSpecializationScope) { + auto if_then_else = Downcast(op->body); + if (!if_then_else.defined()) { + return StmtExprMutator::VisitStmt_(op); + } + auto producer_body = if_then_else->then_case; + Optional consumer_body = if_then_else->else_case; + // In some degenerate warp-specialized patterns (e.g., producer-only), + // the consumer body may be absent. Handle gracefully by only annotating + // the producer side when consumer is missing. 
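+      // A (0, 0) hint pair means no explicit set_max_nreg call was found; in that case the
+      // default 24 (producer) / 240 (consumer) hints are injected below, unless the producer
+      // performs a SIMT copy.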
+ + auto dec_reg = nreg_[0].as()->value; + auto inc_reg = nreg_[1].as()->value; + + auto inc_reg_stmt = Evaluate(0); + auto dec_reg_stmt = Evaluate(0); + + // Only inject if we have valid register hints and no SIMT copy + bool has_simt_copy = SimtCopyDetector::Detect(producer_body); + + if (dec_reg == 0 && inc_reg == 0 && !has_simt_copy) { + auto inc_reg_num = IntImm(DataType::Int(32), 240); + auto dec_reg_num = IntImm(DataType::Int(32), 24); + inc_reg_stmt = Evaluate( + Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1})); + dec_reg_stmt = Evaluate( + Call(DataType::Handle(), set_max_nreg(), {dec_reg_num, 0})); + } + + // Inject register setting statements + Array producer_stmts; + producer_stmts.push_back(dec_reg_stmt); + producer_stmts.push_back(producer_body); + auto new_producer_body = SeqStmt(producer_stmts); + + Stmt new_if_stmt; + if (consumer_body.defined()) { + Array consumer_stmts; + consumer_stmts.push_back(inc_reg_stmt); + consumer_stmts.push_back(consumer_body.value()); + auto new_consumer_body = SeqStmt(consumer_stmts); + new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body, + new_consumer_body); + } else { + // No consumer branch; keep the if-then form. + new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body); + } + + auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt); + return new_attr; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Array nreg_; + IterVar thread_iv_; + Optional updated_thread_extent_; + bool need_update_thread_extent_ = false; +}; + +using namespace tir::transform; + +tvm::transform::Pass AnnotateWarpGroupRegAlloc() { + auto pass_func = [](PrimFunc f, const IRModule &m, + const PassContext &ctx) -> PrimFunc { + return SetMaxNRegInjector::Inject(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc", + AnnotateWarpGroupRegAlloc); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/arg_binder.cc b/tilelang/original/src/transform/arg_binder.cc new file mode 100644 index 0000000000000000000000000000000000000000..294c9f6bc6cafc8c9eb3deb5fc152fa371492853 --- /dev/null +++ b/tilelang/original/src/transform/arg_binder.cc @@ -0,0 +1,928 @@ +/*! + * \file arg_binder.cc + * \brief Helper utility to match and bind arguments. 
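+ * When a binding constraint cannot be proven at lowering time, a structured packed error
+ * call (expected vs. got) is emitted instead of a plain assert.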
+ */ +#include "arg_binder.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "../runtime/error_helpers.h" +#include "tir/transforms/ir_utils.h" +#include "tvm/arith/int_solver.h" +#include "tvm/ffi/cast.h" +#include "tvm/ffi/container/array.h" +#include "tvm/tir/stmt.h" +#include "tvm/tir/stmt_functor.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, + const std::string &arg_name, std::vector *asserts, + PrimExpr nullable_guard = PrimExpr()) { + PrimExpr scond = ana->Simplify(cond); + if (is_zero(scond)) { + LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; + } + + if (!is_one(scond)) { + // Extract kernel/buffer/field from arg_name (e.g., "main.A.shape[0]") + std::string kernel = arg_name; + std::string buf_and_field = arg_name; + size_t dot_pos = arg_name.find('.'); + if (dot_pos != std::string::npos) { + kernel = arg_name.substr(0, dot_pos); + buf_and_field = arg_name.substr(dot_pos + 1); + } + std::string buffer = buf_and_field; + std::string field; + size_t dot2 = buf_and_field.find('.'); + if (dot2 != std::string::npos) { + buffer = buf_and_field.substr(0, dot2); + field = buf_and_field.substr(dot2 + 1); + } + + // If cond is an equality, prefer structured packed error with expect/got + if (const auto *eq = scond.as()) { + PrimExpr lhs = eq->a; + PrimExpr rhs = eq->b; + // Choose rhs as expected and lhs as got for better semantics in most + // binding cases + ffi::Array pargs; + pargs.push_back(StringImm(tvm_error_expect_eq)); + pargs.push_back(StringImm(kernel)); + pargs.push_back(StringImm(buffer)); + pargs.push_back(StringImm(field.empty() ? std::string("value") : field)); + pargs.push_back(cast(DataType::Int(64), rhs)); // expected + pargs.push_back(cast(DataType::Int(64), lhs)); // got + + Stmt call_err = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); + // Only emit at runtime when the equality fails + Stmt inner = IfThenElse(Not(scond), call_err); + if (nullable_guard.defined()) { + inner = IfThenElse(Not(nullable_guard), inner); + } + asserts->emplace_back(SeqStmt({inner, Evaluate(0)})); + } else { + // Fallback: packed generic constraint violation without dumping cond + ffi::Array pargs; + pargs.push_back(StringImm(tvm_error_constraint_violation)); + pargs.push_back(StringImm(kernel)); + pargs.push_back(StringImm(buffer)); + pargs.push_back(StringImm(field.empty() ? std::string("value") : field)); + Stmt call_err = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); + Stmt inner = IfThenElse(Not(scond), call_err); + if (nullable_guard.defined()) { + inner = IfThenElse(Not(nullable_guard), inner); + } + asserts->emplace_back(SeqStmt({inner, Evaluate(0)})); + } + } +} + +std::vector ArgBinder::getUndefVars(const std::vector &args) { + std::unordered_set visit; + std::vector res; + for (const auto &arg : args) { + PostOrderVisit(arg, [&](ObjectRef r) { + if (auto var = r.as()) { + if (!visit.count(var)) { + visit.insert(var); + } + auto it = def_map_->find(var); + if (it == def_map_->end()) { + // res.push_back(var); + res.push_back(ffi::GetRef(var)); + } + } + }); + } + return res; +} + +bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets, + const PrimExpr &nullable_guard) { + // Currently only used in BindDLTensor, nullable_guard is already a defined + // bool, so use it directly. 
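+  // When the guard holds (the incoming handle is null), every assert emitted below is skipped at runtime.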
+ auto MakeGuarded = [&](PrimExpr basic) -> PrimExpr { + // is_null || basic + return Or(nullable_guard, basic); + }; + ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; + auto BindVar = [&](const VarNode *v, PrimExpr value) { + auto v_arg = ffi::GetRef(v); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = value; + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + } else { + (*def_map_)[v] = value; + } + }; + // 1. simple binding var = value + if (const VarNode *v = arg.as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + BindVar(v, value); + // First time binding: identical behavior as Bind_ + return true; + } else { + // Second or later binding: add is_null short-circuit + PrimExpr cond = value == it->second; + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard); + } + } else { + // 2. complex binding expr = value + // get undefined variables + auto undefs = ffi::Array(getUndefVars({arg})); + if (!undefs.empty()) { + // if value is not integer, such as float, we are unable to solve it + if (!value.dtype().is_int() && !value.dtype().is_uint()) { + LOG(FATAL) << "Unable to solve non-integer variables " << undefs + << " from equation `" << value << "`"; + } + arith::IntConstraints constraints(undefs, {}, {arg == value}); + auto sol = arith::SolveLinearEquations(constraints); + if (!sol->dst->variables.empty()) { + LOG(FATAL) << "TVM is unable to solve variables " << undefs + << " from equation " << constraints; + } + for (const auto &v : undefs) { + auto value_opt = sol->src_to_dst.Get(v); + ICHECK(value_opt->defined()) + << "Unable to solve variable `" << v << "` from expression `" + << (value == arg) << "`"; + auto value = ffi::GetRef(sol->src_to_dst.Get(v)->get()); + BindVar(v.as(), value); + } + } + // we must add the assert again + // because the solved expression may contain floordiv (e.g. 
3 * m == n + // ==> m = n // 3) we re-compute the constraint to verify the solution + // is correct + PrimExpr cond = value == arg; + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard); + } + // ICHECK(false); + return false; +} + +bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets) { + ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; + if (const VarNode *v = arg.as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + Var v_arg = Downcast(arg); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = arg; + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + } else { + (*def_map_)[v] = value; + } + return true; + } else { + BinderAddAssert(&analyzer_, value == it->second, arg_name, &asserts_); + } + } else { + BinderAddAssert(&analyzer_, value == arg, arg_name, &asserts_); + } + return false; +} + +void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_let) { + Bind_(arg, value, arg_name, with_let); +} + +void ArgBinder::BindArray(const ffi::Array &arg, + const ffi::Array &value, + const std::string &arg_name) { + ICHECK_EQ(arg.size(), value.size()) + << "Argument " << arg_name << " array size mismatch"; + for (size_t i = 0; i < arg.size(); ++i) { + std::ostringstream os; + os << arg_name << "[" << i << "]"; + this->Bind(arg[i], value[i], os.str()); + } +} + +void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value, + const std::string &arg_name, bool fuzzy_match) { + ICHECK_EQ(arg.scope(), value.scope()) + << "Argument " << arg_name << " Buffer bind scope mismatch"; + // Relax dtype check to allow FP8 E4M3 variants to bind together. + auto dtype_compatible = [](DataType expected, DataType provided) -> bool { + if (expected == provided) + return true; + // If expected is float8_e4m3, allow float8_e4m3fn/float8_e4m3fnuz as well. + if (expected.is_float8_e4m3()) { + return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() || + provided.is_float8_e4m3fnuz(); + } + // If expected is float8_e5m2, allow float8_e5m2fnuz as well. + if (expected.is_float8_e5m2()) { + return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz(); + } + // If expected is bool, allow binding from int8/uint8 with same lanes. + if (expected.is_bool()) { + bool is_i8 = provided.is_int() && provided.bits() == 8; + bool is_u8 = provided.is_uint() && provided.bits() == 8; + return (is_i8 || is_u8) && expected.lanes() == provided.lanes(); + } + return false; + }; + ICHECK(dtype_compatible(arg->dtype, value->dtype)) + << "Argument " << arg_name << " Buffer bind data type mismatch: expected " + << arg->dtype << ", got " << value->dtype; + if (value->data_alignment % arg->data_alignment != 0) { + LOG(WARNING) << "Trying to bind buffer to another one with lower alignment " + "requirement " + << " required_alignment=" << arg->data_alignment + << ", provided_alignment=" << value->data_alignment; + } + + if (value->elem_offset.defined()) { + // bind pointer and offset. 
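+    // (Illustrative) Example of the alignment constraint generated below:
+    // with arg->offset_factor == 4 and a symbolic provided elem_offset, the
+    // emitted check is  truncmod(elem_offset, 4) == 0  on the bound value.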
+ if (is_zero(arg->elem_offset)) { + ICHECK(is_zero(value->elem_offset)) + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << arg->elem_offset + << ", provided elem_offset=" << value->elem_offset; + } + + this->Bind(arg->data, value->data, arg_name + ".data"); + if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", + false)) { + if (arg->offset_factor > 1) { + PrimExpr offset = value->elem_offset; + PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); + BinderAddAssert(&analyzer_, zero == truncmod(offset, factor), + arg_name + ".elem_offset", &asserts_); + } + } + } + + if (arg->shape.size() < value->shape.size()) { + ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; + size_t diff = value->shape.size() - arg->shape.size(); + for (size_t i = 0; i < diff; ++i) { + ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) + << "Argument " << arg_name << " shape mismatch" << arg->shape + << " vs " << value->shape; + } + for (size_t i = 0; i < arg->shape.size(); ++i) { + std::ostringstream os; + os << arg_name << ".shape[" << i << "]"; + this->Bind(arg->shape[i], value->shape[i + diff], os.str()); + } + if (!value->strides.empty()) { + ICHECK_EQ(arg->strides.size(), arg->shape.size()); + ICHECK_EQ(value->strides.size(), value->shape.size()); + for (size_t i = 0; i < arg->strides.size(); ++i) { + std::ostringstream os; + os << arg_name << ".strides[" << i << "]"; + this->Bind(arg->strides[i], value->strides[i + diff], os.str()); + } + } + } else { + this->BindArray(arg->shape, value->shape, arg_name + ".shape"); + this->BindArray(arg->strides, value->strides, arg_name + ".strides"); + } +} + +inline PrimExpr TVMArrayGet(DataType t, Var arr, + builtin::TVMStructFieldKind kind) { + return TVMStructGet(t, arr, 0, kind); +} + +void ArgBinder::BindDLTensors( + const std::vector> &buffer_def, + const PrimExpr &device_type, const PrimExpr &device_id, + const std::string &func_name, + const std::unordered_set &used_param_buffers) { + ffi::Array buffers; + ffi::Array handles; + + // First pass: collect shape var -> list of (buffer_name, dim_idx, handle_ptr) + struct ShapeVarSource { + std::string buf_name; + size_t dim_idx; + const VarNode *handle_ptr; // Raw pointer to check used_param_buffers + }; + std::unordered_map> + shape_var_sources; + + for (const auto &[handle, buffer] : buffer_def) { + std::string arg_name = func_name + "." + buffer->data->name_hint; + + // Scan buffer shape for symbolic variables + for (size_t k = 0; k < buffer->shape.size(); ++k) { + if (buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1)) { + break; + } + + if (const VarNode *v = buffer->shape[k].as()) { + // This dimension is a symbolic variable + shape_var_sources[v].push_back({arg_name, k, handle.get()}); + } + } + } + + // Second pass: Create is_null vars and shape buffers for all buffers first + std::unordered_map is_null_map; + std::unordered_map shape_buffer_map; + std::unordered_map + is_null_expr_map; // arg_name -> is_null expression (const_false for used + // buffers) + + const DataType tvm_shape_type = DataType::ShapeIndex(); + const DataType tvm_ndim_type = DataType::Int(32); + const Stmt nop = Evaluate(0); + + // Create all is_null vars and shape buffers first + for (const auto &[handle, buffer] : buffer_def) { + bool is_used = used_param_buffers.count(handle.get()); + std::string arg_name = func_name + "." 
+ buffer->data->name_hint; + + Var is_null_var(arg_name + "_is_null", DataType::Bool()); + init_nest_.emplace_back( + LetStmt(is_null_var, + Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop)); + const PrimExpr &is_null = is_used ? const_false() : is_null_var; + + is_null_map[arg_name] = is_null_var; + is_null_expr_map[arg_name] = is_null; + + if (is_used) { + init_nest_.emplace_back( + AssertStmt(!is_null_var, + tvm::tir::StringImm( + arg_name + " is expected to have non-NULL pointer"), + nop)); + } + } + + // Create all shape buffers before binding any shapes + for (const auto &[handle, buffer] : buffer_def) { + std::string arg_name = func_name + "." + buffer->data->name_hint; + const PrimExpr &is_null = is_null_expr_map[arg_name]; + + // Helper functions for shape/stride name formatting + auto shape_handle_name = [&]() { return arg_name + ".shape"; }; + + // shape field + Buffer buf_shape = + decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, + tvm_shape_type, shape_handle_name()); + def_handle_dtype_.Set(buf_shape->data, make_const(tvm_shape_type, 0)); + // Use if_then_else for NULL guard on the shape pointer itself, avoiding + // dereferencing TVMStructGet(handle, kArrShape) when handle is NULL. + init_nest_.emplace_back( + LetStmt(buf_shape->data, + tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), + make_zero(DataType::Handle())), + nop)); + init_nest_.emplace_back(DeclBuffer(buf_shape, nop)); + + // Save for later use in shape binding + shape_buffer_map[arg_name] = buf_shape; + } + + // Now process each buffer fully + for (const auto &[handle, buffer] : buffer_def) { + bool is_used = used_param_buffers.count(handle.get()); + std::string arg_name = func_name + "." + buffer->data->name_hint; + const PrimExpr &is_null = is_null_expr_map[arg_name]; + + // dimension checks + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); + + // Helper functions for shape/stride name formatting + auto shape_handle_name = [&]() { return arg_name + ".shape"; }; + auto stride_handle_name = [&]() { return arg_name + ".strides"; }; + auto array_element_name = [&](const std::string &arr_name, size_t k) { + std::stringstream ss; + ss << arr_name << '[' << k << ']'; + return ss.str(); + }; + auto shape_element_name = [&](size_t k) { + return array_element_name(shape_handle_name(), k); + }; + auto stride_element_name = [&](size_t k) { + return array_element_name(stride_handle_name(), k); + }; + + PrimExpr a_ndim = + make_const(tvm_ndim_type, static_cast(buffer->shape.size())); + // Build clearer ndim message with kernel/buffer names + std::string kernel_nm = arg_name; + std::string buf_nm = arg_name; + size_t dot_pos = arg_name.find('.'); + if (dot_pos != std::string::npos) { + kernel_nm = arg_name.substr(0, dot_pos); + buf_nm = arg_name.substr(dot_pos + 1); + } + // Only check ndim when handle is non-NULL: use packed error helper + PrimExpr ndim_ok = (a_ndim == v_ndim); + ffi::Array ndim_args; + ndim_args.push_back(StringImm(tvm_error_ndim_mismatch)); + ndim_args.push_back(StringImm(kernel_nm)); + ndim_args.push_back(StringImm(buf_nm)); + ndim_args.push_back(cast(DataType::Int(64), a_ndim)); + ndim_args.push_back(cast(DataType::Int(64), v_ndim)); + Stmt ndim_call = Evaluate( + Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args)); + init_nest_.emplace_back( + SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call), + Evaluate(0)), + nop})); + // type checks + // Guard all dtype field loads by 
`is_null` using if_then_else + PrimExpr v_type_code = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode), + IntImm(DataType::UInt(8), buffer->dtype.code())); + PrimExpr v_type_bits = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits), + IntImm(DataType::UInt(8), buffer->dtype.bits())); + PrimExpr v_type_lanes = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes), + IntImm(DataType::UInt(16), buffer->dtype.lanes())); + PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code()); + PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits()); + PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes()); + + PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + + // Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime. + if (buffer->dtype.is_float8_e4m3()) { + PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3); + PrimExpr code_e4m3fn = + IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn); + PrimExpr code_e4m3fnuz = + IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz); + PrimExpr code_match = + (v_type_code == code_e4m3 || v_type_code == code_e4m3fn || + v_type_code == code_e4m3fnuz); + cond = cond || (code_match && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + } + // Allow float8_e5m2 to match float8_e5m2fnuz at runtime. + if (buffer->dtype.is_float8_e5m2()) { + PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2); + PrimExpr code_e5m2fnuz = + IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz); + PrimExpr code_match = + (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz); + cond = cond || (code_match && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + } + // Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6). + if (buffer->dtype.is_bool()) { + PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt); + PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt); + PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6); + PrimExpr bits8 = IntImm(DataType::UInt(8), 8); + PrimExpr bits1 = IntImm(DataType::UInt(8), 1); + PrimExpr lanes_ok = (v_type_lanes == expect_lanes); + PrimExpr int8_ok = + (v_type_code == code_int && v_type_bits == bits8 && lanes_ok); + PrimExpr uint8_ok = + (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok); + // Some frontends may tag bool tensors as kDLBool(code=6), commonly with + // bits=8 or bits=1. + PrimExpr kdlbool8_ok = + (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok); + PrimExpr kdlbool1_ok = + (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok); + // Also accept any dtype whose bitwidth=1, regardless of code, to be + // defensive. + PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok); + cond = + cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok; + } + // Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for + // FP4). 
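+    // (Illustrative) e.g. a scalar float4 buffer (lanes == 1) also accepts an
+    // incoming tensor tagged as int8 with lanes == 1; only the type code and
+    // bit width are relaxed, the lane count must still equal expect_lanes.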
+ if (buffer->dtype.is_float4()) { + PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt); + PrimExpr bits8 = IntImm(DataType::UInt(8), 8); + // For FP4, we pack 2 elements per byte, but we still use same lanes at + // storage level Accept int8 with same lanes as the fp4 type + PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes); + PrimExpr int8_ok = + (v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok); + cond = cond || int8_ok; + } + if (!(buffer->dtype == DataType::Int(1) || + buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) { + // Build FFI packed call to __tvm_error_dtype_mismatch when mismatch + // occurs. Only issue the call when handle is non-NULL and cond is false. + ffi::Array packed_args; + packed_args.push_back(StringImm(tvm_error_dtype_mismatch)); + // Split arg_name of the form "." into parts for clearer + // diagnostics + std::string kernel_name = arg_name; + std::string buffer_name = arg_name; + size_t dot_pos = arg_name.find('.'); + if (dot_pos != std::string::npos) { + kernel_name = arg_name.substr(0, dot_pos); + buffer_name = arg_name.substr(dot_pos + 1); + } + packed_args.push_back(StringImm(kernel_name)); + packed_args.push_back(StringImm(buffer_name)); + + auto i64 = DataType::Int(64); + // Cast to int64 for FFI function signature + packed_args.push_back(cast(i64, v_type_code)); // actual_code + packed_args.push_back(cast(i64, v_type_bits)); // actual_bits + packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes + packed_args.push_back(cast(i64, expect_code)); // expect_code + packed_args.push_back(cast(i64, expect_bits)); // expect_bits + packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes + + Stmt call_err = Evaluate( + Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args)); + // Guard the call: only when handle is not null and cond fails + Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err); + asserts_.emplace_back(SeqStmt({guarded, nop})); + } + + // Get the pre-created shape buffer + Buffer buf_shape = shape_buffer_map[arg_name]; + + // Bind symbolic variables from buffer shape + for (size_t k = 0; k < buffer->shape.size(); ++k) { + // These packed-bit dtype shapes were not bound in the original + // implementation, so we just use them as is. 
+ if (buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1)) { + break; + } + + // The "real" runtime shape value read from DLTensor + PrimExpr shape_val = + cast(buffer->shape[k].dtype(), + BufferLoad(buf_shape, + {IntImm(DataType::Int(32), static_cast(k))})); + + // Check if this dimension is a symbolic variable + if (const VarNode *v = buffer->shape[k].as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + // First time binding this symbolic variable + auto sources_it = shape_var_sources.find(v); + if (sources_it != shape_var_sources.end() && + sources_it->second.size() > 1) { + // This variable appears in multiple buffers + // Assert that at least one buffer is non-null + PrimExpr any_nonnull = const_false(); + for (const auto &src : sources_it->second) { + bool buf_is_used = used_param_buffers.count(src.handle_ptr); + if (buf_is_used) { + any_nonnull = const_true(); + break; + } + Var src_is_null = is_null_map[src.buf_name]; + any_nonnull = Or(any_nonnull, Not(src_is_null)); + } + + std::ostringstream err_msg; + err_msg << "Symbolic shape variable " + << ffi::GetRef(v)->name_hint + << " requires at least one non-null buffer among: "; + bool first = true; + for (const auto &src : sources_it->second) { + if (!first) + err_msg << ", "; + err_msg << src.buf_name; + first = false; + } + + init_nest_.emplace_back(AssertStmt( + any_nonnull, tvm::tir::StringImm(err_msg.str()), nop)); + + // Build cascaded if_then_else: if !is_null_a then a.shape[k] else + // if !is_null_b then b.shape[k] ... We need to construct this in + // reverse order + PrimExpr cascaded_value; + bool is_first_source = true; + + for (auto rit = sources_it->second.rbegin(); + rit != sources_it->second.rend(); ++rit) { + const auto &src = *rit; + + // Get the shape buffer for this source + auto it_buf = shape_buffer_map.find(src.buf_name); + if (it_buf == shape_buffer_map.end()) { + LOG(FATAL) << "Shape buffer not found for " << src.buf_name; + } + Buffer src_shape_buf = it_buf->second; + + // Construct the shape load + PrimExpr src_shape_val = + cast(buffer->shape[k].dtype(), + BufferLoad(src_shape_buf, + {IntImm(DataType::Int(32), + static_cast(src.dim_idx))})); + + // Check if this buffer is used (non-nullable) + bool src_is_used = used_param_buffers.count(src.handle_ptr); + + if (is_first_source) { + // Base case: use this shape value directly (we know at least + // one is non-null from assert) + cascaded_value = src_shape_val; + is_first_source = false; + } else { + // if !is_null then use this shape, else use previous cascaded + // value But if buffer is used (non-nullable), always use its + // shape + if (src_is_used) { + cascaded_value = src_shape_val; + } else { + Var src_is_null = is_null_map[src.buf_name]; + cascaded_value = tvm::if_then_else( + Not(src_is_null), src_shape_val, cascaded_value); + } + } + } + + // Bind the variable to the cascaded expression + Var v_arg = ffi::GetRef(v); + defs_.emplace_back(v_arg); + (*def_map_)[v] = cascaded_value; + init_nest_.emplace_back( + LetStmt(v_arg, cascaded_value, Evaluate(0))); + } else { + // Single source or no special handling needed, use the original + // nullable binding + BindNullable(buffer->shape[k], shape_val, shape_element_name(k), + true, is_null); + } + } else { + // Variable already bound, add assertion with nullable guard + PrimExpr cond = (it->second == shape_val); + BinderAddAssert(&analyzer_, cond, shape_element_name(k), &asserts_, + is_null); + } + } else { + // Constant 
dimension, just add assertion + BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true, + is_null); + } + } + + // strides field + Buffer buf_strides = + decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, + tvm_shape_type, arg_name + ".strides"); + def_handle_dtype_.Set(buf_strides->data, + tir::TypeAnnotation(tvm_shape_type)); + init_nest_.emplace_back( + LetStmt(buf_strides->data, + tvm::if_then_else(Not(is_null), + TVMArrayGet(DataType::Handle(), handle, + builtin::kArrStrides), + make_zero(DataType::Handle())), + nop)); + init_nest_.emplace_back(DeclBuffer(buf_strides, nop)); + PrimExpr v_strides_is_null = + Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); + + if (buffer->strides.empty()) { + // Assert the buffer is compact + DataType stype = buffer->DefaultIndexType(); + PrimExpr expect_stride = make_const(stype, 1); + ffi::Array conds; + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + PrimExpr svalue = + cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), + static_cast(k))})); + conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue); + expect_stride = expect_stride * buffer->shape[k]; + } + std::ostringstream stride_err_msg; + stride_err_msg + << stride_handle_name() + << ": expected to be compact array, but got non-compact strides"; + if (!conds.empty()) { + PrimExpr all_ok = + foldl([](PrimExpr a, PrimExpr b, + Span span) { return logical_and(a, b, span); }, + const_true(1), conds); + // Packed generic violation for non-compact strides + std::string kernel_nm3 = arg_name; + std::string buf_nm3 = arg_name; + size_t dot_pos3 = arg_name.find('.'); + if (dot_pos3 != std::string::npos) { + kernel_nm3 = arg_name.substr(0, dot_pos3); + buf_nm3 = arg_name.substr(dot_pos3 + 1); + } + ffi::Array pargs4; + pargs4.push_back(StringImm(tvm_error_constraint_violation)); + pargs4.push_back(StringImm(kernel_nm3)); + pargs4.push_back(StringImm(buf_nm3)); + pargs4.push_back(StringImm("strides")); + Stmt call_err4 = Evaluate( + Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4)); + // Only check when strides array is present and condition fails + Stmt check = + IfThenElse(Not(v_strides_is_null), + IfThenElse(Not(all_ok), call_err4), Evaluate(0)); + asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); + } + } else if (buffer->buffer_type == kAutoBroadcast) { + PrimExpr stride_from_shape = 1; + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + DataType stride_dtype = buffer->strides[k].dtype(); + PrimExpr explicit_stride = + cast(stride_dtype, + BufferLoad(buf_strides, + {IntImm(DataType::Int(32), static_cast(k))})); + + PrimExpr stride_val = tvm::if_then_else( + v_strides_is_null, stride_from_shape, explicit_stride); + + BindNullable(buffer->strides[k], stride_val, stride_element_name(k), + true, is_null); + } + } else { + PrimExpr stride_from_shape = 1; + + for (int k = static_cast(buffer->strides.size()) - 1; k >= 0; --k) { + DataType stride_dtype = buffer->strides[k].dtype(); + PrimExpr explicit_stride = + cast(stride_dtype, + BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr shape_stride = + cast(stride_dtype, + BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); + + PrimExpr stride_val = tvm::if_then_else( + v_strides_is_null, stride_from_shape, explicit_stride); + + BindNullable(buffer->strides[k], stride_val, stride_element_name(k), + true, is_null); + } + } + + // Byte_offset field. 
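+    // (Illustrative) For a constant elem_offset the check below compares
+    // kArrByteOffset against elem_offset * sizeof(dtype); e.g. elem_offset 16
+    // on a float32 buffer expects a byte offset of 64.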
+ int data_bytes = GetVectorBytes(buffer->dtype); + + if (const auto *const_offset = buffer->elem_offset.as()) { + // Constant elem_offset: only need consistency check, no need for + // additional Var binding. + PrimExpr actual_byte_offset = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + make_const(DataType::UInt(64), 0)); + PrimExpr expect_byte_offset = + make_const(DataType::UInt(64), const_offset->value * data_bytes); + PrimExpr ok = (expect_byte_offset == actual_byte_offset); + ffi::Array pargs; + pargs.push_back(StringImm(tvm_error_byte_offset_mismatch)); + pargs.push_back(StringImm(kernel_nm)); + pargs.push_back(StringImm(buf_nm)); + pargs.push_back(cast(DataType::Int(64), expect_byte_offset)); + pargs.push_back(cast(DataType::Int(64), actual_byte_offset)); + Stmt call_err = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); + asserts_.emplace_back(SeqStmt( + {IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)), + nop})); + } else { + PrimExpr actual_byte_offset = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + make_const(DataType::UInt(64), 0)); + PrimExpr expect_elem_off = cast( + buffer->elem_offset.dtype(), + (actual_byte_offset / make_const(DataType::UInt(64), data_bytes))); + + BindNullable(buffer->elem_offset, expect_elem_off, + arg_name + ".elem_offset", true, is_null); + + if (buffer->offset_factor > 1) { + PrimExpr offset = buffer->elem_offset; + PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); + BindNullable(offset, truncmod(offset, factor), + arg_name + ".elem_offset", true, is_null); + } + } + + // device info. + // Define device_id from handle when available (so later passes can use it) + PrimExpr actual_dev_type = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), + make_zero(DataType::Int(32))); + PrimExpr actual_dev_id = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), + make_zero(DataType::Int(32))); + + // Bind device_id to a safe expression (0 when NULL handle) + BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true, + is_null); + // Check device_type consistency (device_id equality is implicitly ensured + // by binding above) + { + PrimExpr ok = (device_type == actual_dev_type); + ffi::Array pargs2; + pargs2.push_back(StringImm(tvm_error_device_type_mismatch)); + pargs2.push_back(StringImm(kernel_nm)); + pargs2.push_back(StringImm(buf_nm)); + pargs2.push_back(cast(DataType::Int(64), device_type)); + pargs2.push_back(cast(DataType::Int(64), actual_dev_type)); + Stmt call_err2 = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2)); + asserts_.emplace_back( + SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2), + Evaluate(0)), + Evaluate(0)})); + } + + // Data field. Because the validation of the data field may depend + // on a dynamic size defined by the other DLTensor* parameters, this + // field must be generated last. + // Bind data pointer using expression-level guard to avoid deref on NULL. + { + Var vptr(buffer->data); + PrimExpr data_ptr = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), + make_zero(DataType::Handle())); + BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null); + + // Check if the data pointer is NULL. 
This check is skipped for + // size-0 arrays and also skipped when handle itself is NULL. + PrimExpr alloc_size = IntImm(buffer->DefaultIndexType(), 1); + for (const auto &dim : buffer->shape) { + alloc_size = alloc_size * dim; + } + // Improve message: kernel/buffer naming for data pointer null check + std::string kernel_nm2 = arg_name; + std::string buf_nm2 = arg_name; + size_t dot_pos2 = arg_name.find('.'); + if (dot_pos2 != std::string::npos) { + kernel_nm2 = arg_name.substr(0, dot_pos2); + buf_nm2 = arg_name.substr(dot_pos2 + 1); + } + // expand combined condition via nested IfThenElse for portability + ffi::Array pargs3; + pargs3.push_back(StringImm(tvm_error_null_ptr)); + pargs3.push_back(StringImm(kernel_nm2)); + pargs3.push_back(StringImm(buf_nm2)); + pargs3.push_back(StringImm("data pointer")); + Stmt call_err3 = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3)); + asserts_.emplace_back(SeqStmt( + {IfThenElse(Not(is_null), + IfThenElse(Not(alloc_size == 0), + IfThenElse(Call(DataType::Bool(), + builtin::isnullptr(), {vptr}), + call_err3), + Evaluate(0)), + Evaluate(0)), + nop})); + + // mark alignment of external bufs + init_nest_.emplace_back( + AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); + + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); + } + } +} + +} // namespace tl +} // namespace tvm \ No newline at end of file diff --git a/tilelang/original/src/transform/arg_binder.h b/tilelang/original/src/transform/arg_binder.h new file mode 100644 index 0000000000000000000000000000000000000000..bb7a0f46fd452dfe6cb44fa100ad9f94035a8719 --- /dev/null +++ b/tilelang/original/src/transform/arg_binder.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file arg_binder.h + * \brief Helper utility to match and bind arguments. + */ +#ifndef TVM_TL_TRANSFORM_ARG_BINDER_H_ +#define TVM_TL_TRANSFORM_ARG_BINDER_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Helper utility to generate match and bind of arguments. + * + * \note There is many places in TVM IR where we need argument bindings. + * + * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)). + * Here n is a undefined variable that is decided by the outside, tB imposes + * a constraint such that it can only take tensor with shape 3, tC imposes + * another constraint that it's shape must equals n + 2. 
+ * So if we call it with f(bufferA, bufferB, bufferC), we need to generate
+ * the following binding sequence:
+ *   - define n = bufferA.shape[0]
+ *   - assert bufferB.shape[0] == 3
+ *   - assert bufferC.shape[0] == n + 2
+ *
+ * In general, this is a constraint-solving problem. We make a simplifying
+ * assumption about the binding declaration: every variable that occurs in a
+ * constraint must be declared in the argument list. So it is illegal to have
+ * the signature f(tA(shape=(n+3))) without an argument variable corresponding
+ * to n, even though n could already be derived from the input argument.
+ */
+class ArgBinder {
+public:
+  /*!
+   * \brief Constructor
+   * \param def_map A definition map that contains definitions of known
+   * variables. ArgBinder will update this def_map when adding new definitions.
+   */
+  explicit ArgBinder(std::unordered_map<const VarNode *, PrimExpr> *def_map)
+      : def_map_(def_map) {}
+  /*!
+   * \brief Try to bind arg to value, generating a constraint if necessary.
+   * \param arg The argument to be bound.
+   * \param value The target expression value
+   * \param arg_name argument name.
+   * \param with_let Whether to add let bindings during bind
+   */
+  void Bind(const PrimExpr &arg, const PrimExpr &value,
+            const std::string &arg_name, bool with_let = false);
+  /*!
+   * \brief Bind array to array
+   * \param arg The argument to be bound.
+   * \param value The target expression value
+   * \param arg_name argument name.
+   */
+  void BindArray(const ffi::Array<PrimExpr> &arg,
+                 const ffi::Array<PrimExpr> &value,
+                 const std::string &arg_name);
+  /*!
+   * \brief Bind symbolic buffer to another symbolic buffer
+   * \param arg The argument to be bound.
+   * \param value The target expression value
+   * \param arg_name argument name.
+   * \param fuzzy_match If enabled, we allow value's dimension to be smaller
+   * than arg's, as long as arg's higher dimensions are of size 1.
+   */
+  void BindBuffer(const Buffer &arg, const Buffer &value,
+                  const std::string &arg_name, bool fuzzy_match);
+
+  /*!
+   * \brief Bind symbolic buffers to DLTensor handles.
+   * \param buffer_def The buffer definitions as (handle, buffer) pairs.
+   * \param device_type The device type to be bound.
+   * \param device_id The device id to be bound.
+   * \param func_name The function name.
+   * \param used_param_buffers The parameter buffers that are actually used
+   * and therefore must be non-NULL.
+   */
+  void
+  BindDLTensors(const std::vector<std::pair<Var, Buffer>> &buffer_def,
+                const PrimExpr &device_type, const PrimExpr &device_id,
+                const std::string &func_name,
+                const std::unordered_set<const VarNode *> &used_param_buffers);
+
+  /*! \return The defs generated in binding. */
+  const std::vector<Var> &defs() const { return defs_; }
+
+  /*! \return The asserts generated in binding
+   *
+   * This contains statements that assert the correct value has been
+   * bound. For example, `binder.Bind(var, expr_1)` will produce an
+   * entry mapping `var` to `expr_1` in the `binder.defs()`. If
+   * `binder.Bind(var, expr_2)` is called later, then this will
+   * produce an assert statement that `expr_1 == expr_2`.
+   *
+   * Note: Some assert statements produced by BindDLTensors are located
+   * in `binder.init_nest()`, not within `binder.asserts()`. This is
+   * deliberate, as some values may require checks prior to
+   * initialization. (e.g. Initializing `m = dl_tensor->shape[3]`
+   * requires first asserting that `3 < dl_tensor->ndim`.)
+   */
+  const std::vector<Stmt> &asserts() const { return asserts_; }
+
+  /*!
+   * \brief Initialization nest generated
+   *
+   * This contains both variable bindings and any assert statements
+   * that are required in order to safely produce those variable
+   * bindings.
+   *
+   * \note Variable bindings may be implemented either as a `LetStmt`
+   * that defines the variable, or as a variable replacement. Any
+   * bindings implemented as a `LetStmt` will be in the
+   * initialization list. Any bindings implemented as a variable
+   * replacement will be stored in the `var_def` map.
+   *
+   * A `tir::LetStmt` is usually generated when binding to a
+   * `DLTensor`. This requires loading values from memory, which
+   * should only be performed once. If the binding to a
+   * `DLTensor` were implemented as a variable replacement, it
+   * would load values from memory once for each usage of the
+   * variable.
+   *
+   * \return The initialization nest generated during binding.
+   */
+  const std::vector<Stmt> &init_nest() const { return init_nest_; }
+  /*! \return Handle data type of the data */
+  const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
+    return def_handle_dtype_;
+  }
+
+  bool BindNullable(const PrimExpr &arg, const PrimExpr &value,
+                    const std::string &arg_name, bool with_lets,
+                    const PrimExpr &nullable_guard);
+
+private:
+  std::vector<Var> getUndefVars(const std::vector<PrimExpr> &args);
+  // Internal bind function
+  bool Bind_(const PrimExpr &arg, const PrimExpr &value,
+             const std::string &arg_name, bool with_lets);
+  /*! \brief The definition map, can be used to substitute */
+  std::unordered_map<const VarNode *, PrimExpr> *def_map_;
+  /*! \brief defs generated in the current binder */
+  std::vector<Var> defs_;
+  /*! \brief Initialization nest */
+  std::vector<Stmt> init_nest_;
+  /*! \brief handle data type in the definitions */
+  ffi::Map<Var, PrimExpr> def_handle_dtype_;
+  /*! \brief asserts generated */
+  std::vector<Stmt> asserts_;
+  /*! \brief internal analyzer. */
+  arith::Analyzer analyzer_;
+};
+} // namespace tl
+} // namespace tvm
+#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
\ No newline at end of file
diff --git a/tilelang/original/src/transform/atomicadd_vectorize.cc b/tilelang/original/src/transform/atomicadd_vectorize.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d66a538dbe3560dc8ffbbf414b1da8c3886c266f
--- /dev/null
+++ b/tilelang/original/src/transform/atomicadd_vectorize.cc
@@ -0,0 +1,308 @@
+/*!
+ * \file atomicadd_vectorize.cc + * \brief A tool to automatically vectorize atomic add + */ + +#include "atomicadd_vectorize.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default; + +AtomicAddVectorizePlanResult +AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) { + int vectorize_size_max = 1; + this->vector_size_ = 4; + this->dynamic_ = false; + this->condition_ = PrimExpr(); + + PostOrderVisit(node, [&](const ObjectRef &obj) { + if (const auto *call = obj.as()) { + if (call->op == atomicadd_elem_op()) { + if (call->args.size() < 2) { + // Fallback: unexpected arity + vectorize_size_max = 1; + DLOG(WARNING) << "[AtomicAddVectorizePlanner] atomicadd_elem_op " + "expects 2 args, got " + << call->args.size() << "; Fallback to no vectorize"; + return; + } + DataType dtype; + if (const auto *load = call->args[0].as()) { + dtype = load->dtype; + vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); + } else if (const auto *ite = call->args[0].as()) { + if (const auto *then_load = ite->then_case.as()) { + dtype = then_load->dtype; + vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); + } else if (const auto *else_load = + ite->else_case.as()) { + dtype = else_load->dtype; + vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); + } else { + // fallback + vectorize_size_max = 1; + DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case " + "has no BufferLoad; Fallback to no vectorize"; + } + } else { + // fallback + vectorize_size_max = 1; + DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type " + << call->args[1]->GetTypeKey() + << "; Fallback to no vectorize"; + } + } + } + }); + + if (vectorize_size_max <= 1) { + return {1, dynamic_, condition_}; + } + + this->max_vector_size = vectorize_size_max; + this->operator()(node); + return {vector_size_, dynamic_, condition_}; +} + +void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) { + inner_for_ = node; + arith::IRVisitorWithAnalyzer::VisitStmt_(node); +} + +void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) { + if (node->op == atomicadd_elem_op() && !node->args.empty()) { + if (node->args.size() < 2) { + return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + } + const BufferLoadNode *buffer_load_dst = node->args[0].as(); + const BufferLoadNode *buffer_load_src = node->args[1].as(); + if (buffer_load_src && buffer_load_src->buffer.defined() && + buffer_load_dst && buffer_load_dst->buffer.defined()) { + Buffer dst_buffer = buffer_load_dst->buffer; + UpdateVectorSize(buffer_load_dst->indices, dst_buffer); + + Buffer src_buffer = buffer_load_src->buffer; + UpdateVectorSize(buffer_load_src->indices, src_buffer); + } + } + return arith::IRVisitorWithAnalyzer::VisitExpr_(node); +} + +int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability, + DataType dtype) { + if (dtype == DataType::Float(16)) { + return 2; + } + if (dtype == DataType::BFloat(16)) { + return compute_capability > 75 ? 2 : 1; + } + if (dtype == DataType::Float(32)) { + return compute_capability >= 90 ? 
4 : 1; + } + return 1; +} + +void AtomicAddVectorizePlanner::UpdateVectorSize(const Array &indices, + const Buffer &buffer) { + if (!inner_for_) + return; + auto extent_ptr = inner_for_->extent.as(); + if (!extent_ptr) + return; + + const DataType &access_type = buffer->dtype; + max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); + + auto last_dim = buffer->shape.back(); + auto mod_set = analyzer_.modular_set(last_dim); + + if (buffer->shape.back().as()) { + max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff); + auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base); + + if (gcd_base < Downcast(last_dim)->value) { + max_vector_size = gcd_base; + } + + vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); + + PrimExpr elem_offset = 0; + PrimExpr stride = 1; + for (int i = indices.size() - 1; i >= 0; --i) { + elem_offset = elem_offset + indices[i] * stride; + stride = stride * buffer->shape[i]; + } + + while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, + inner_for_->extent, vector_size_, &analyzer_)) { + vector_size_ /= 2; + } + } else if (vector_size_ <= 4) { + dynamic_ = true; + PrimExpr offset = buffer.OffsetOf(indices).back(); + condition_ = (truncmod(offset, vector_size_) == 0); + } +} + +class AtomicAddVectorizeRewriter : public StmtExprMutator { +public: + AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan) + : vector_size_(plan.vector_size), dynamic_(plan.dynamic), + condition_(plan.condition) {} + +private: + /** + * @brief Visits a For node and rewrites the innermost loop for atomic-add + * vectorization. + * + * If the visited For node is the recorded innermost loop, this method + * validates that the loop extent is a constant, divisible by the planned + * vector size, and has a zero minimum. When vectorization is enabled + * (dynamic_ == false) it: + * - locates the thread index variable named "tx" inside the loop body, + * - creates a new outer loop variable named "_outer", + * - substitutes occurrences of `tx` with `tx * vector_size_` and the old + * loop var with `outer_var * vector_size_` so each outer iteration maps to a + * contiguous vector-sized chunk, + * - returns a new For with extent divided by vector_size_ and the + * transformed body. + * + * If dynamic_ is true, the method returns the (possibly mutated) inner For + * unchanged. + * + * Side effects: + * - updates inner_for_ to point to the current For node during visitation. + * - performs runtime checks (ICHECK) to enforce: constant extent, extent % + * vector_size_ == 0, and zero loop minimum; violations terminate execution. + * + * @return The original or transformed For statement as a Stmt. 
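+   *
+   * @note Example (illustrative): with vector_size_ == 4 and dynamic_ ==
+   * false, a loop `for (i, 0, 128)` is rewritten to `for (i, 0, 32)` and
+   * every use of the original loop variable in the body is replaced by
+   * `i * 4`.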
+ */ + Stmt VisitStmt_(const ForNode *node) final { + inner_for_ = node; + auto ret = StmtExprMutator::VisitStmt_(node); + if (vector_size_ == 1) + return ret; + if (inner_for_ == node) { + For fnode = ret.as().value(); + auto old_var = fnode->loop_var; + auto new_var = Var(old_var->name_hint); + auto extent_ptr = as_const_int(fnode->extent); + ICHECK(extent_ptr) << fnode->extent; + int extent = *extent_ptr; + ICHECK(extent % vector_size_ == 0) + << "extent: " << extent << " vector_size_: " << vector_size_; + ICHECK(is_zero(fnode->min)); + if (!dynamic_) { + Map vmap; + vmap.Set(old_var, new_var * vector_size_); + Stmt body = Substitute(fnode->body, vmap); + return For(new_var, 0, extent / vector_size_, fnode->kind, body, + fnode->thread_binding, fnode->annotations, fnode->step, + fnode->span); + } + } + return ret; + } + + PrimExpr VisitExpr_(const CallNode *node) final { + bool legal_vectorize = true; + if (dynamic_) + legal_vectorize = false; + if (!(node->op == atomicadd_elem_op())) + legal_vectorize = false; + if (node->args.size() < 2) + legal_vectorize = false; + if (legal_vectorize) { + const BufferLoadNode *temp_dst_node = node->args[0].as(); + const BufferLoadNode *temp_value_node = + node->args[1].as(); + if (!temp_dst_node || !temp_value_node) + legal_vectorize = false; + } + if (legal_vectorize) { + const BufferLoad dst_node = Downcast(node->args[0]); + const BufferLoad value_node = Downcast(node->args[1]); + // The default memory order is relaxed + // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd + const IntImm memory_order = + node->args.size() >= 3 ? Downcast(node->args[2]) : IntImm(0); + Array new_args; + Call address_of_dst = + Call(DataType::Handle(), builtin::address_of(), {dst_node}); + Call address_of_value = + Call(DataType::Handle(), builtin::address_of(), {value_node}); + if (vector_size_ == 4) { + new_args.push_back(StringImm("AtomicAddx4")); + new_args.push_back(address_of_dst); + new_args.push_back(address_of_value); + } else if (vector_size_ == 2) { + new_args.push_back(StringImm("AtomicAddx2")); + new_args.push_back(address_of_dst); + new_args.push_back(address_of_value); + } else { + // Scalar case: AtomicAdd now expects a pointer to destination. + new_args.push_back(StringImm("AtomicAdd")); + new_args.push_back(address_of_dst); + new_args.push_back(value_node); + } + new_args.push_back(memory_order); + + Call new_call = + tvm::tir::Call(node->dtype, builtin::call_extern(), new_args); + + return new_call; + } else { + Array new_args; + new_args.push_back(StringImm("AtomicAdd")); + // Ensure first argument is an address; keep value as-is. + if (!node->args.empty()) { + if (const auto *bl = node->args[0].as()) { + Call address_of_dst = Call(DataType::Handle(), builtin::address_of(), + {Downcast(node->args[0])}); + new_args.push_back(address_of_dst); + } else if (const auto *call = node->args[0].as()) { + // If it's already an address_of, forward it; otherwise, keep + // original. + if (call->op.same_as(builtin::address_of())) { + new_args.push_back(node->args[0]); + } else { + new_args.push_back(node->args[0]); + } + } else { + new_args.push_back(node->args[0]); + } + // Push remaining args unchanged (value, optional memory_order, ...) 
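+      // (Illustrative) a scalar fallback therefore lowers to roughly
+      //   call_extern("AtomicAdd", address_of(dst[i]), value, memory_order)
+      // when the first argument was a BufferLoad and a memory order was given.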
+ for (size_t i = 1; i < node->args.size(); ++i) { + new_args.push_back(node->args[i]); + } + } + + Call new_call = + tvm::tir::Call(node->dtype, builtin::call_extern(), new_args); + + return new_call; + } + } + + const ForNode *inner_for_; + const int vector_size_; + const PrimExpr condition_; + const bool dynamic_; +}; + +For VectorizeAtomicAdd(const For &for_node, int compute_capability) { + AtomicAddVectorizePlanResult res = {1, false, 0}; + AtomicAddVectorizePlanner planner; + res = planner.Plan(for_node, compute_capability); + auto rewriter = AtomicAddVectorizeRewriter(res); + return Downcast(rewriter(for_node)); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/atomicadd_vectorize.h b/tilelang/original/src/transform/atomicadd_vectorize.h new file mode 100644 index 0000000000000000000000000000000000000000..627dc895f403c27c233061f1cfc9d20461801eca --- /dev/null +++ b/tilelang/original/src/transform/atomicadd_vectorize.h @@ -0,0 +1,60 @@ +/*! + * \file atomicadd_vectorize.h + * \brief A tool to automatically vectorize a for atomicadd + */ + +#ifndef TVM_TL_ATOMICADD_VECTORIZE_H_ +#define TVM_TL_ATOMICADD_VECTORIZE_H_ + +#include "../layout/layout.h" +#include "../layout/utils.h" +#include "../op/builtin.h" +#include "arith/int_operator.h" +#include "arith/ir_visitor_with_analyzer.h" +#include "common/loop_vectorization_utils.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +For VectorizeAtomicAdd(const For &for_node, int compute_capability); + +struct AtomicAddVectorizePlanResult { + int vector_size; + bool dynamic; + PrimExpr condition; +}; + +class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer { +public: + AtomicAddVectorizePlanner(); + + AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability); + +private: + void VisitStmt_(const ForNode *node) final; + void VisitExpr_(const CallNode *node) final; + + int GetVectorizeSizeMax(int compute_capability, DataType dtype); + void UpdateVectorSize(const Array &indices, const Buffer &buffer); + + const ForNode *inner_for_ = nullptr; + bool has_nonlocal_memory_access_ = false; + int vector_size_ = 4; + int max_vector_size = 1; + bool dynamic_ = false; + PrimExpr condition_; +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_ATOMICADD_VECTORIZE_H_ \ No newline at end of file diff --git a/tilelang/original/src/transform/cluster_planning.cc b/tilelang/original/src/transform/cluster_planning.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fcdc1691f925d4e3715269a5118716d104eb156 --- /dev/null +++ b/tilelang/original/src/transform/cluster_planning.cc @@ -0,0 +1,135 @@ +/*! 
+ * \file clasuter_planning.cc + * \brief Plan the cluster for GPU(sm90+) blocks + */ + +#include +#include +#include +#include +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tir { + +class ClusterPlanner { +public: + static PrimFunc Substitute(PrimFunc &f) { + // Step 1: Collect the read region of the function + Map buffer_data_to_buffer_; + for (const auto &[_, buffer] : f->buffer_map) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ f->body); + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + + BlockIdxVisitor blockIdx_visitor; + blockIdx_visitor(f->body); + auto dom_map = blockIdx_visitor.dom_map_; + + // Step 2: Collect mem reuse count for clustering on each dimension. + std::unordered_map mem_reuse_count; + for (auto iv : dom_map) + mem_reuse_count[iv] = 0; + + for (const auto &buffer_region : reads) { + PrimExpr size = buffer_region->buffer->dtype.bits(); + RegionVisitor visitor; + for (const auto &range : buffer_region->region) { + size = size * range->extent; + visitor(range->min); + } + size = arith::Analyzer().Simplify(size); + if (auto imm = size.as()) { + for (auto iv : dom_map) { + if (visitor.seen_.count(iv->var.get()) == 0) + mem_reuse_count[iv] += imm->value; + } + } + } + + // Step 3: Pick the cluster dimension with the largest mem_reuse. + size_t mem_reuse_max = 0; + String cluster_tag; + for (auto iv : dom_map) { + if (auto extent = iv->dom->extent.as()) { + if (extent->value % cluster_size_ == 0 && + mem_reuse_count[iv] > mem_reuse_max) { + cluster_tag = iv->thread_tag; + mem_reuse_max = mem_reuse_count[iv]; + } + } + } + + if (mem_reuse_max > 0) { + std::string tag_str = + static_cast(cluster_tag); // Convert to std::string + if (tag_str.rfind("blockIdx", 0) == 0) { + // starts with "blockIdx" + tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx")); + } else { + // Unexpected format — maybe just prefix + tag_str = "clusterIdx" + tag_str; + } + cluster_tag = String(tag_str); // Convert back + return WithAttr(f, cluster_tag, Integer(cluster_size_)); + } else { + return f; + } + } + +private: + ClusterPlanner() = default; + + class RegionVisitor : public ExprVisitor { + public: + RegionVisitor() {}; + void VisitExpr_(const VarNode *var) { seen_.insert(var); } + std::unordered_set seen_; + }; + + class BlockIdxVisitor : public StmtVisitor { + public: + BlockIdxVisitor() {}; + void VisitStmt_(const AttrStmtNode *attr) final { + if (attr->attr_key == attr::thread_extent) { + IterVar iv = Downcast(attr->node); + String tag = iv->thread_tag; + if (tag == "blockIdx.x" || tag == "blockIdx.y" || tag == "blockIdx.z") + dom_map_.insert(iv.get()); + } + StmtVisitor::VisitStmt_(attr); + } + /*! \brief The map from vars to blockidx extents. */ + std::unordered_set dom_map_; + }; + + /*! 
\brief Currently set the plossible cluster size as 2 */ + const static int cluster_size_ = 2; +}; + +PrimFunc ClusterPlanning(PrimFunc f) { return ClusterPlanner::Substitute(f); } + +namespace transform { + +tvm::transform::Pass ClusterPlanning() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return ClusterPlanning(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning); +} +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tilelang/original/src/transform/common/assume.cc b/tilelang/original/src/transform/common/assume.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb51d0f8a043db73f638572720bd7b71e15d79cb --- /dev/null +++ b/tilelang/original/src/transform/common/assume.cc @@ -0,0 +1,33 @@ + +/*! + * \file assume.cc + * \brief Utils on assume statements + */ + +#include "assume.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +std::optional GetAssumeExprInEvaluateForm(Stmt stmt) { + auto eval = stmt.as(); + if (!eval) + return std::nullopt; + auto call = eval->value.as(); + if (!call) + return std::nullopt; + if (!call->op.same_as(builtin::assume())) + return std::nullopt; + return call->args[0]; +} + +bool IsAssumeInEvaluateForm(const Stmt &stmt) { + return GetAssumeExprInEvaluateForm(stmt).has_value(); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/common/assume.h b/tilelang/original/src/transform/common/assume.h new file mode 100644 index 0000000000000000000000000000000000000000..c6eadc6b341ef18d8062412350c52e007551cbee --- /dev/null +++ b/tilelang/original/src/transform/common/assume.h @@ -0,0 +1,28 @@ + +/*! + * \file assume.h + * \brief Utils on assume statements + */ + +#ifndef TVM_TL_TRANSFORM_COMMON_ASSUME_H_ +#define TVM_TL_TRANSFORM_COMMON_ASSUME_H_ + +#include "tvm/tir/stmt.h" +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Get the expression inside an assume statement, if any. Returns nullopt if +// the statement is not an assume statement. +std::optional GetAssumeExprInEvaluateForm(Stmt stmt); + +// Check if a statement is an assume statement. +bool IsAssumeInEvaluateForm(const Stmt &stmt); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_ASSUME_H_ \ No newline at end of file diff --git a/tilelang/original/src/transform/common/attr.h b/tilelang/original/src/transform/common/attr.h new file mode 100644 index 0000000000000000000000000000000000000000..d71ee67231f6037d17ea8071fb21a477ea2628da --- /dev/null +++ b/tilelang/original/src/transform/common/attr.h @@ -0,0 +1,15 @@ +/*! + * \file attr.h + * \brief Check attributes of the IR + */ + +namespace tvm { +namespace tl { + +constexpr const char *MainBlockName = "tilelang_root"; + +constexpr const char *tilelang_is_cpu_kernel_frame = + "tilelang.is_cpu_kernel_frame"; + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/common/collector.h b/tilelang/original/src/transform/common/collector.h new file mode 100644 index 0000000000000000000000000000000000000000..28227703ddf39d3db4b889f17f8e98d10981f786 --- /dev/null +++ b/tilelang/original/src/transform/common/collector.h @@ -0,0 +1,74 @@ +/*! 
+ * \file collector.h + * \brief Collect information from the IR + */ + +#include "arith/ir_visitor_with_analyzer.h" +#include "tir/analysis/var_use_def_analysis.h" +#include +#include +#include +#include +#include + +#include "../../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class ThreadTagChecker : public StmtExprVisitor { +public: + static bool HasOnlyThreadIdxX(const PrimFunc &f) { + ThreadTagChecker checker; + checker(f->body); + return checker.is_valid_; + } + + static IterVar GetThreadVar(const Stmt &body) { + ThreadTagChecker checker; + checker(body); + return checker.thread_var_; + } + +private: + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iter_var = Downcast(op->node); + String thread_tag = iter_var->thread_tag; + bool is_y_or_z = + thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z"; + + if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) { + is_valid_ = false; + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kThreadBinding) { + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + if (thread_tag == "threadIdx.x") { + thread_var_ = Downcast(op->thread_binding); + } + bool is_y_or_z = + thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z"; + if (!thread_tag.empty() && is_y_or_z) { + auto iter_var = Downcast(op->thread_binding); + if (iter_var.defined() && iter_var->dom.defined() && + !is_one(iter_var->dom->extent)) { + is_valid_ = false; + } + } + } + StmtExprVisitor::VisitStmt_(op); + } + IterVar thread_var_; + bool is_valid_ = true; +}; + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/common/loop_fusion_utils.h b/tilelang/original/src/transform/common/loop_fusion_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..2fa6cdede4fcdc7827be31a1bd2f003172b1d021 --- /dev/null +++ b/tilelang/original/src/transform/common/loop_fusion_utils.h @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file common.h + * \brief Common utilities for TL transforms + */ + +#include +#include +#include +#include +#include + +#include + +#include "../../op/parallel.h" +#include "../loop_partition.h" +#include "../loop_vectorize.h" +#include "arith/ir_mutator_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; + +class FragmentAccessDetector : public StmtExprVisitor { +public: + FragmentAccessDetector() = default; + + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + bool HasFragmentAccess() { return has_fragment_access_; } + +private: + void VisitExpr_(const BufferLoadNode *op) final { + // Check if the buffer is in global scope + if (IsFragmentBuffer(op->buffer)) { + has_fragment_access_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + // Check if the buffer is in global scope + if (IsFragmentBuffer(op->buffer)) { + has_fragment_access_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + // Helper function to determine if a buffer is local.fragment + bool IsFragmentBuffer(const Buffer &buffer) { + // The storage scope is often encoded in the buffer->data var name or + // associated attributes. + String scope = buffer.scope(); + return scope == "local.fragment"; + } + + bool has_fragment_access_{false}; +}; + +/*! + * \brief ParallelLoopFuser + * This class is used to fuse a chain of parallel loops into one loop. + * The loops must: + * - All be parallel (ForKind::kParallel) + * - Have bounds from 0 to their extent + * Once fused, a single loop variable will replace the chain, and the + * original loop variables will be derived by division and modulo operations. + * + * This can be helpful for inferring layout for the fragment in a subsequent + * pass. + */ +class ParallelLoopFuser : public IRMutatorWithAnalyzer { +public: + static Stmt Fuse(const Stmt &stmt) { + arith::Analyzer analyzer; + ParallelLoopFuser substituter(&analyzer); + return substituter.VisitStmt(stmt); + } + +private: + ParallelLoopFuser(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {}; + + Stmt VisitStmt_(const ForNode *op) final { + // Gather consecutive parallel loops + std::vector loop_chain; + const ForNode *current = op; + // check if has fragment access + FragmentAccessDetector detector; + detector.Collect(op->body); + // Do not fuse if there is a fragment access + if (detector.HasFragmentAccess()) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + while (true) { + if (current->kind != ForKind::kParallel) + break; + if (!is_zero(current->min)) + break; + loop_chain.push_back(current); + + const ForNode *inner_for = current->body.as(); + if (!inner_for) { + break; + } + current = inner_for; + } + + // If only one loop found or loop chain size is 1, no fusion needed. + if (loop_chain.size() <= 1) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + // If one of the loop has extent which is not 2^n, we do not fuse + for (auto l : loop_chain) { + PrimExpr extent = l->extent; + // If extent is not a constant integer, we cannot determine if it's power + // of 2 + if (!extent.as()) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + int64_t value = extent.as()->value; + // Check if value is power of 2: value > 0 and only has one bit set + if (value <= 0 || (value & (value - 1)) != 0) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + } + + // At this point we have multiple nested parallel loops starting at zero + // We will fuse them all. 
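+    // (Illustrative) e.g. fusing `i in [0, 4)` with `j in [0, 8)` produces a
+    // single parallel loop `f in [0, 32)`, with i recovered as f / 8 and j as
+    // f % 8 (see the stride computation below).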
+    PrimExpr fused_extent = make_const(DataType::Int(32), 1);
+    for (auto it = loop_chain.rbegin(); it != loop_chain.rend(); ++it) {
+      fused_extent = fused_extent * (*it)->extent;
+    }
+
+    std::string fused_name;
+    for (auto it = loop_chain.begin(); it != loop_chain.end(); ++it) {
+      fused_name += (*it)->loop_var->name_hint + "_";
+    }
+
+    fused_name += "fused";
+
+    // Create a new fused loop var
+    Var fused_var(fused_name, DataType::Int(32));
+
+    // The body of the last loop in the chain:
+    const ForNode *innermost_loop = loop_chain.back();
+    Stmt body = innermost_loop->body;
+
+    // We need to substitute all loop variables in the chain.
+    // The scheme:
+    //   Suppose we have loops (i in [0,M], j in [0,N], k in [0,O])
+    //   fused loop var f in [0, M*N*O]
+    //   i = f / (N*O)
+    //   j = (f % (N*O)) / O
+    //   k = f % O
+    //
+    // Generalizing for a chain of length L:
+    //   extents: E_0, E_1, ..., E_{L-1}
+    //   index_i = (f / (E_{i+1}*...*E_{L-1})) % E_i
+    //   For the last one (i == L-1) this is just f % E_{L-1}.
+
+    // Compute the "stride" products for each loop variable
+    //   stride[i] = product of extents of loops after i
+    // for L loops:
+    //   stride[L-1] = 1
+    //   stride[L-2] = E_{L-1}
+    //   stride[L-3] = E_{L-1} * E_{L-2}
+    //   ...
+    std::vector<PrimExpr> extents;
+    extents.reserve(loop_chain.size());
+    for (auto l : loop_chain) {
+      extents.push_back(l->extent);
+    }
+
+    std::vector<PrimExpr> strides(loop_chain.size(),
+                                  make_const(DataType::Int(32), 1));
+    for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
+      strides[i] = strides[i + 1] * extents[i + 1];
+    }
+
+    // We'll create a substitution map for all loop variables:
+    //   index_i = (f / strides[i]) % extents[i]
+    // We'll define a helper lambda:
+    auto create_index_expr = [&](int i) {
+      return FloorMod(FloorDiv(fused_var, strides[i]), extents[i]);
+    };
+
+    Map<Var, PrimExpr> var_map;
+    for (size_t i = 0; i < loop_chain.size(); i++) {
+      const ForNode *loop = loop_chain[i];
+      var_map.Set(loop->loop_var,
+                  analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
+    }
+
+    // Perform the substitution
+    body = Substitute(body, var_map);
+
+    // Create the fused loop
+    For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);
+    fused_for.CopyOnWrite()->annotations = op->annotations;
+    return fused_for;
+  }
+};
+
+} // namespace tl
+} // namespace tvm
diff --git a/tilelang/original/src/transform/common/loop_parallel_transform_utils.h b/tilelang/original/src/transform/common/loop_parallel_transform_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..1e8d7a35042a2d7a1916adb1f8dd1a5da5819e9b
--- /dev/null
+++ b/tilelang/original/src/transform/common/loop_parallel_transform_utils.h
@@ -0,0 +1,168 @@
+/*!
+ * \file loop_parallel_transform_utils.h
+ * \brief Common utilities for TL transforms
+ */
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "arith/ir_mutator_with_analyzer.h"
+#include "arith/ir_visitor_with_analyzer.h"
+#include
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+using arith::IRMutatorWithAnalyzer;
+using arith::IRVisitorWithAnalyzer;
+
+class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
+public:
+  static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) {
+    arith::Analyzer analyzer;
+    ParallelLoopTransformer transformer(&analyzer);
+    return transformer.VisitStmt(stmt);
+  }
+
+  ParallelLoopTransformer(arith::Analyzer *analyzer)
+      : IRMutatorWithAnalyzer(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode *op) final {
+    if (op->kind != ForKind::kParallel)
+      return StmtMutator::VisitStmt_(op);
+
+    // Collect loop variables and ranges
+    auto for_node = tvm::ffi::GetRef<For>(op);
+    Array<Var> loop_vars;
+    Array<PrimExpr> loop_extents;
+    Stmt body = op->body;
+
+    // Bind the range of the outer loop variable
+    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
+    loop_vars.push_back(op->loop_var);
+    loop_extents.push_back(op->extent);
+
+    // If there are inner loops, bind their ranges as well
+    while (const ForNode *inner = body.as<ForNode>()) {
+      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
+      loop_vars.push_back(inner->loop_var);
+      loop_extents.push_back(inner->extent);
+      body = inner->body;
+    }
+
+    ICHECK(loop_vars.size() == loop_extents.size())
+        << "loop_vars and loop_extents size mismatch";
+
+    // Collect buffer access information
+    BufferAccessCollector collector;
+    collector(op->body);
+
+    PrimExpr condition;
+
+    for (const auto &[buffer, indices] : collector.buffer_indices) {
+      ICHECK(indices.size() == buffer->shape.size())
+          << "indices size mismatch with buffer shape";
+
+      for (size_t i = 0; i < indices.size(); ++i) {
+        auto index = indices[i];
+        auto bound = analyzer_->const_int_bound(index);
+
+        // Collect the variables that are used in the index
+        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
+        // Post-order visit the index expression
+        PostOrderVisit(index, [&](const ObjectRef &obj) {
+          if (const VarNode *v = obj.as<VarNode>()) {
+            used_vars.insert(tvm::ffi::GetRef<Var>(v));
+          }
+        });
+        if (used_vars.empty()) {
+          continue;
+        }
+
+        // Find the related loop vars
+        Array<Var> related_loop_vars;
+        for (size_t j = 0; j < loop_vars.size(); ++j) {
+          auto loop_var = loop_vars[j];
+          // Record the loop vars that appear in this index expression
+          if (used_vars.count(loop_var)) {
+            related_loop_vars.push_back(loop_var);
+          }
+          if (related_loop_vars.size() > 1) {
+            // Only a single related loop var is supported by this
+            // transformation currently.
+            return for_node;
+          }
+
+          auto bound = analyzer_->const_int_bound(index);
+          int64_t upper_bound = bound->max_value + 1;
+          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
+          if (upper_bound < shape) {
+            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
+            condition =
+                condition.defined() ? And(condition, predicate) : predicate;
+          }
+        }
+      }
+    }
+
+    if (condition.defined()) {
+      body = IfThenElse(condition, body);
+
+      for (int j = loop_vars.size() - 1; j >= 0; --j) {
+        auto loop_var = loop_vars[j];
+        auto loop_extent = loop_extents[j];
+        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
+      }
+
+      return Downcast<For>(body);
+    }
+
+    // Only traverse the outer loop
+    return for_node;
+  }
+
+  // Helper class for collecting buffer access information; it only records
+  // accesses to local.fragment buffers.
+  class BufferAccessCollector : public StmtExprVisitor {
+  public:
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
+          buffer_indices[op->buffer] = op->indices;
+        } else {
+          // All accesses to the same buffer must use identical indices
+          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
+              << "indices mismatch for buffer: " << op->buffer;
+        }
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
+          buffer_indices[op->buffer] = op->indices;
+        } else {
+          // All accesses to the same buffer must use identical indices
+          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
+              << "indices mismatch for buffer: " << op->buffer;
+        }
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+
+    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+        buffer_indices;
+  };
+};
+
+} // namespace tl
+} // namespace tvm
diff --git a/tilelang/original/src/transform/common/loop_vectorization_utils.h b/tilelang/original/src/transform/common/loop_vectorization_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..b9b7715d0913f14619c72e17c367ed686a2b2baa
--- /dev/null
+++ b/tilelang/original/src/transform/common/loop_vectorization_utils.h
@@ -0,0 +1,784 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common.h + * \brief Common utilities for TL transforms + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "../../op/parallel.h" +#include "../loop_partition.h" +#include "../loop_vectorize.h" +#include "arith/ir_mutator_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +// Vectorize Part +// Use the same code as tir.transform.vectorize_loop +inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { + if (is_scalable) { + return Mul(Call(DataType::Int(32), builtin::vscale(), {}), + lanes_or_vscale_factor); + } else { + return lanes_or_vscale_factor; + } +} + +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { + // Check if e is already in the expected form + if (e.dtype().get_lanes_or_vscale_factor() == lanes && + e.dtype().is_scalable_vector() == is_scalable) + return e; + + if (const BroadcastNode *op = e.as()) { + ICHECK(op->dtype.is_scalable_vector() == is_scalable) + << "Can't broadcast between scalable and fixed length vectors."; + int e_lanes = op->dtype.get_lanes_or_vscale_factor(); + + if (lanes % e_lanes == 0) { + return Broadcast(op->value, CreateNewLanes(is_scalable, lanes)); + } + } + + ICHECK(e.dtype().is_scalar()) + << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor() + << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes; + + return Broadcast(e, CreateNewLanes(is_scalable, lanes)); +} + +// Rewrite vectorized allocation access +// This is necessary for making each vector component containing its own +// workspace. Originates from Halide's loop vectorizer +// +// s[i] = s[i * lanes + var] +// +// The same principle applies when using one thread to simulate multiple +// context. +// +class VecAllocAccess : public StmtExprMutator { +public: + VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes) + : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {} + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(store); + } + +private: + template Node UpdateBufferAccess(Node node) { + // Only update the buffer that's being replaced. + if (node->buffer->data.get() != buf_) { + return node; + } + + // Find/make a Buffer object with the correct updated shape. + Buffer buf; + auto it = buffer_map_.find(node->buffer.get()); + if (it != buffer_map_.end()) { + buf = it->second; + } else { + // Extend the least significant dimension by a factor of + // var_lanes_. Typically, this will be a 1-d index into a flat + // memory space. + Array shape = node->buffer->shape; + shape.Set(shape.size() - 1, + analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); + + // TODO(Lunderberg): Move this pass to be prior to + // StorageFlatten/FlattenBuffer, implement by appending a + // dimension to the buffer. Since it is currently after the + // flattening, the strides are not technically necessary, but + // are updated for consistency. + + // Update strides if defined. + Array strides; + for (size_t i = 0; i < strides.size(); i++) { + PrimExpr stride = strides[i]; + if (i != strides.size() - 1) { + stride *= var_lanes_; + } + strides.push_back(analyzer_.Simplify(stride)); + } + + // Copy everything into the new buffer. 
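+      // For example, a flat [8] workspace accessed under a 4-lane vectorized
+      // loop becomes a [32] buffer, and index i is later rewritten to
+      // i * 4 + var.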
+ buf = node->buffer; + auto buf_writer = buf.CopyOnWrite(); + buf_writer->shape = shape; + buf_writer->strides = strides; + buffer_map_[buf.get()] = buf; + } + + // Extend the last index by the number of lanes in the vectorized + // variable. + Array indices = node->indices; + indices.Set( + indices.size() - 1, + analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); + + auto writer = node.CopyOnWrite(); + writer->buffer = buf; + writer->indices = indices; + return node; + } + + // buffer var + const VarNode *buf_; + // Updated buffer objects. + std::unordered_map buffer_map_; + // variable to be replaced + Var var_; + // the lanes. + PrimExpr var_lanes_; + // Analyzer for simplifications + arith::Analyzer analyzer_; +}; + +// We use ExprFunctor directly instead of StmtExprMutator +// This is because the transformation can change the dtype of the Expr +// The existing ExprMutator transformation rules may not be well defined. +class Vectorizer : public StmtMutator, + public ExprFunctor { +public: + using ExprFunctor::VisitExpr; + using StmtMutator::operator(); + + Vectorizer(const Var &var, const PrimExpr &var_lanes) + : var_(var), var_lanes_(var_lanes) { + ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes); + } + + Stmt VisitStmt(const Stmt &stmt) final { + ICHECK(!need_scalarize_); + Stmt ret = StmtMutator::VisitStmt(stmt); + if (need_scalarize_) { + need_scalarize_ = false; + return Scalarize(stmt); + } else { + return ret; + } + } + + PrimExpr VisitExpr(const PrimExpr &e) final { + return ExprFunctor::VisitExpr(e); + } + + PrimExpr VisitExpr_(const AddNode *op) final { + return AddSubVec( + op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); }); + } + + PrimExpr VisitExpr_(const SubNode *op) final { + return AddSubVec( + op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); }); + } + + PrimExpr VisitExpr_(const MulNode *op) final { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); + bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); + if (is_vec_a && is_vec_b) { + // Let's not multiply scalable and fixed length vectors + ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector()) + << "Fixed length and scalable vectors can't be mixed in " + "multiplication."; + } + if (is_vec_a || is_vec_b) { + const RampNode *b_ramp = b.as(); + const RampNode *a_ramp = a.as(); + if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) { + PrimExpr lanes = a_ramp->lanes; + return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); + } + if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) { + PrimExpr lanes = b_ramp->lanes; + return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); + } + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int max_lanes = std::max(a_lanes, b_lanes); + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return Mul(BroadcastTo(a, max_lanes, is_scalable), + BroadcastTo(b, max_lanes, is_scalable)); + } + } + return BinaryVec(op); + } + PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec
(op); } + PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode *op) final { + return BinaryVec(op); + } + PrimExpr VisitExpr_(const FloorModNode *op) final { + return BinaryVec(op); + } + PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec(op); } + + PrimExpr VisitExpr_(const NotNode *op) final { + PrimExpr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return tvm::ffi::GetRef(op); + } else { + return !(a); + } + } + + PrimExpr VisitExpr_(const RampNode *op) final { + PrimExpr base = this->VisitExpr(op->base); + PrimExpr stride = this->VisitExpr(op->stride); + ICHECK(!base.dtype().is_scalable_vector()) + << "Creating scalable vectors from existing vectors is not supported."; + ICHECK(!stride.dtype().is_scalable_vector()) + << "Ramp stride with scalable dtype is not supported"; + if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) { + ICHECK(op->lanes->IsInstance()) + << "Vectorizing over existing scalable vectors is not supported."; + const RampNode *base_ramp = base.as(); + int op_lanes = static_cast(Downcast(op->lanes)->value); + int base_ramp_lanes = + static_cast(Downcast(base_ramp->lanes)->value); + if (analyzer_.CanProve(base_ramp->stride == + stride * + make_const(stride.dtype(), base_ramp_lanes))) { + return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes); + } + } + int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); + Array elems; + for (int i = 0; i < lanes; ++i) { + elems.push_back(Ramp(Shuffle::ExtractElement(base, i), + Shuffle::ExtractElement(stride, i), op->lanes)); + } + return Shuffle::Concat(elems); + } + + PrimExpr VisitExpr_(const BroadcastNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + if (value.same_as(op->value)) { + return tvm::ffi::GetRef(op); + } else { + return Broadcast(op->value, op->lanes); + } + } + + PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr t = this->VisitExpr(op->true_value); + PrimExpr f = this->VisitExpr(op->false_value); + if (cond.same_as(op->condition) && t.same_as(op->true_value) && + f.same_as(op->false_value)) { + return tvm::ffi::GetRef(op); + } else { + int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes); + bool is_scalable = cond.dtype().is_scalable_vector() || + t.dtype().is_scalable_vector() || + f.dtype().is_scalable_vector(); + return Select(BroadcastTo(cond, lanes, is_scalable), + BroadcastTo(t, lanes, 
is_scalable), + BroadcastTo(f, lanes, is_scalable)); + } + } + + PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return tvm::ffi::GetRef(op); + } else { + if (value.dtype().is_scalable_vector()) { + return Cast(op->dtype.with_scalable_vscale_factor( + value.dtype().vscale_factor()), + value); + } else { + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + } + } + } + + PrimExpr VisitExpr_(const FloatImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const IntImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const StringImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + // Variable + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = tvm::ffi::GetRef(op); + + if (var.same_as(var_)) { + return ramp_; + } + auto it = let_binding_.find(var); + if (it != let_binding_.end()) { + return it->second; + } else { + return std::move(var); + } + } + // IfThenElse expr + PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr cond = this->VisitExpr(op->args[0]); + if (cond.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + PrimExpr t = this->VisitExpr(op->args[1]); + PrimExpr f = this->VisitExpr(op->args[2]); + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && + f.same_as(op->args[2])) { + return tvm::ffi::GetRef(op); + } else { + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(t_lanes, f_lanes); + bool is_scalable = + t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector(); + t = BroadcastTo(t, lanes, is_scalable); + f = BroadcastTo(f, lanes, is_scalable); + if (is_scalable) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, + {cond, t, f}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + } + } + } + // Reinterpret expr + PrimExpr MutateReinterpretExpr_(const CallNode *op) { + ICHECK(op->op.same_as(builtin::reinterpret())); + PrimExpr value = this->VisitExpr(op->args[0]); + if (value.same_as(op->args[0])) { + return tvm::ffi::GetRef(op); + } else { + int lanes = value.dtype().get_lanes_or_vscale_factor(); + if (value.dtype().is_scalable_vector()) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, + {value}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {value}); + } + } + } + // Call + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::if_then_else())) { + return MutateIfThenElseExpr_(op); + } else if (op->op.same_as(builtin::texture2d_load())) { + int lane = 0; + Array fcd = MutateArray({op->args.back()}, &lane); + auto new_args = op->args; + new_args.pop_back(); + new_args.push_back(fcd[0]); + return Call(op->dtype.with_lanes(4), op->op, new_args); + } else if (op->op.same_as(builtin::texture2d_store())) { + int lane = 0; + // Vectorize the value to store + Array value{op->args.back()}; + Array mutated_value = MutateArray(value, &lane); + Array new_args{op->args[0], op->args[1], op->args[2], + mutated_value[0]}; + return Call(op->dtype.with_lanes(lane), op->op, new_args); + } else if (op->op.same_as(builtin::reinterpret())) { + return MutateReinterpretExpr_(op); + } + auto optional_op = op->op.as(); + bool vectorizable = optional_op && + op_vectorizable_.get(optional_op.value(), false) && + !op->dtype.is_scalable_vector(); + + if (!vectorizable) { + 
// Cannot vectorize this op + Array new_args; + for (auto arg : op->args) { + auto new_arg = this->VisitExpr(arg); + if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + new_args.push_back(new_arg); + } + if (op->args.same_as(new_args)) { + return tvm::ffi::GetRef(op); + } else { + return Call(op->dtype, op->op, new_args); + } + } else { + int lane = 0; + Array new_args = MutateArray(op->args, &lane); + // normal code path. + if (op->args.same_as(new_args)) { + return tvm::ffi::GetRef(op); + } else { + return Call(op->dtype.with_lanes(lane), op->op, new_args); + } + } + } + // BufferLoad + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = tvm::ffi::GetRef(op); + + auto fmutate = [this](const PrimExpr &index) { + return this->VisitExpr(index); + }; + Array indices = op->indices.Map(fmutate); + + if (!indices.same_as(op->indices)) { + auto writer = load.CopyOnWrite(); + writer->indices = indices; + } + + return std::move(load); + } + // Let + PrimExpr VisitExpr_(const LetNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + ICHECK(deep_equal_(it->second, value)) + << "Let cannot bind the same var to two different values"; + } + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { + Var new_var(op->var->name_hint, value.dtype()); + let_binding_[op->var] = new_var; + return Let(new_var, value, this->VisitExpr(op->body)); + } else { + let_binding_[op->var] = op->var; + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return Let(op->var, value, body); + } + } + } + // BufferStore + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = tvm::ffi::GetRef(op); + + auto fmutate = [this](const PrimExpr &index) { + return this->VisitExpr(index); + }; + Array indices = op->indices.Map(fmutate); + + PrimExpr value = this->VisitExpr(op->value); + + if (!indices.same_as(op->indices) || !value.same_as(op->value)) { + ICHECK(!op->buffer->dtype.is_scalable_vector()) + << "Vectorizing over scalable buffer elements is not supported in " + "vectorizer."; + // How many lanes of indexing are present in the index and + // buffer element type, excluding the last index. + int other_index_lanes = op->buffer->dtype.lanes(); + for (size_t i = 0; i < indices.size() - 1; i++) { + other_index_lanes *= indices[i].dtype().lanes(); + // Only allow the last index to be scalable + ICHECK(!indices[i].dtype().is_scalable_vector()) + << "Only the last index can be scalable."; + } + + // The total number of lanes of indexing, including the last index. + auto last_index_dtype = indices[indices.size() - 1].dtype(); + int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor(); + int index_lanes = other_index_lanes * lanes_in_last_index; + + // The total number of lanes in this store operation. Either + // the index or the value will be broadcast out to this number + // of lanes, depending on which has more lanes. 
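+      // For instance, storing a scalar value through a 4-lane ramp index
+      // broadcasts the value to 4 lanes, while storing a 4-lane vector value
+      // through scalar indices broadcasts the last index to 4 lanes instead.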
+ int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); + bool is_last_index_scalable = last_index_dtype.is_scalable_vector(); + int total_lanes = std::max(index_lanes, value_dtype_lanes); + + ICHECK_EQ(total_lanes % other_index_lanes, 0) + << "When storing to buffer " << op->buffer->name + << ", cannot produce " << total_lanes + << " lanes of storage location by changing the last index."; + int last_index_lanes = total_lanes / other_index_lanes; + + // Broadcast the last index such that the total number of index + // lanes matches the desired number. + indices.Set(indices.size() - 1, + BroadcastTo(indices[indices.size() - 1], last_index_lanes, + is_last_index_scalable)); + + auto writer = store.CopyOnWrite(); + writer->indices = indices; + writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable); + } + + return std::move(store); + } + // For + Stmt VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kVectorized) { + LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; + } + ICHECK(is_zero(op->min)); + ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); + PrimExpr extent = this->VisitExpr(op->extent); + if (extent.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(tvm::ffi::GetRef(op)); + } + Stmt body = this->VisitStmt(op->body); + if (extent.same_as(op->extent) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return For(op->loop_var, op->min, extent, op->kind, body, + op->thread_binding, op->annotations); + } + } + // IfThenElse + Stmt VisitStmt_(const IfThenElseNode *op) final { + ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); + PrimExpr condition = this->VisitExpr(op->condition); + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(tvm::ffi::GetRef(op)); + } + Stmt then_case = this->VisitStmt(op->then_case); + Optional else_case = std::nullopt; + if (op->else_case) { + else_case = this->VisitStmt(op->else_case.value()); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return tvm::ffi::GetRef(op); + } else { + return IfThenElse(condition, then_case, else_case); + } + } + // While + Stmt VisitStmt_(const WhileNode *op) final { + LOG(FATAL) << "A while loop inside a vectorized loop not supported."; + } + // LetStmt + Stmt VisitStmt_(const LetStmtNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + ICHECK(!let_binding_.count(op->var)) + << "SSA violation, a single var is binded twice"; + let_binding_[op->var] = value; + + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { + Var new_var(op->var->name_hint, value.dtype()); + let_binding_[op->var] = new_var; + return LetStmt(new_var, value, this->VisitStmt(op->body)); + } else { + let_binding_[op->var] = op->var; + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + } + // Allocate + Stmt VisitStmt_(const AllocateNode *op) final { + // Mutate the condition + PrimExpr condition = this->VisitExpr(op->condition); + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc of " + << op->buffer_var->name_hint; + return Scalarize(tvm::ffi::GetRef(op)); + } + + // Mutate the extents + Array extents; + for (const auto &extent : op->extents) { + PrimExpr 
new_ext = this->VisitExpr(extent); + if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc of " + << op->buffer_var->name_hint; + return Scalarize(tvm::ffi::GetRef(op)); + } + extents.push_back(new_ext); + } + + // TODO(Lunderberg): Move this pass to be prior to + // StorageFlatten/FlattenBuffer. That will allow this pass to be + // implemented as adding a new buffer dimension, which is later + // flattened. + + // Extend the least significant dimension by a factor of + // var_lanes_. Typically, this will be a 1-d index into a flat + // memory space. + extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_); + + // Rewrite access to the buffer in the body. + Stmt body = + VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); + body = this->VisitStmt(body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); + } + + // scalarize the statement + Stmt Scalarize(Stmt stmt) { + Var idx(var_->name_hint + ".s", var_->dtype); + stmt = Substitute(stmt, {{var_, idx}}); + return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); + } + +private: + // analyzer + arith::Analyzer analyzer_; + // deep equal + ExprDeepEqual deep_equal_; + // variable to be replaced + Var var_; + // the lanes. + PrimExpr var_lanes_; + // ramp representing the var. + PrimExpr ramp_; + // flag to mark requirement of scalarization. + bool need_scalarize_{false}; + // Let binding + std::unordered_map let_binding_; + // vectorizable property + OpAttrMap op_vectorizable_ = + Op::GetAttrMap("TVectorizable"); + + // mutate array, with given lane requirement + // when finished, p_lane updates the lane requirement. + Array MutateArray(Array arr, int *p_lanes) { + if (arr.empty()) + return arr; + int &lanes = *p_lanes; + bool changed = false; + std::vector new_arr(arr.size()); + for (size_t i = 0; i < arr.size(); i++) { + PrimExpr old_elem = arr[i]; + PrimExpr new_elem = this->VisitExpr(old_elem); + if (!new_elem.same_as(old_elem)) + changed = true; + new_arr[i] = new_elem; + lanes = std::max(lanes, new_elem.dtype().lanes()); + } + + for (size_t i = 0; i < arr.size(); ++i) { + if (new_arr[i].dtype().lanes() != lanes) { + new_arr[i] = BroadcastTo(new_arr[i], lanes, false); + changed = true; + } + } + if (!changed) + return arr; + return Array(new_arr); + } + template PrimExpr BinaryVec(const T *op) { + static_assert(std::is_same::value, + "constraint"); + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return TOp(BroadcastTo(a, lanes, is_scalable), + BroadcastTo(b, lanes, is_scalable)); + } + } + template + PrimExpr AddSubVec(const T *op, FCompute fcompute) { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + if (lanes != 1) { + const RampNode *b_ramp = b.as(); + const RampNode *a_ramp = a.as(); + if (a.dtype().is_scalar() && b_ramp) { + return Ramp( + fcompute(a, b_ramp->base), + 
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), + b_ramp->lanes); + } + if (b.dtype().is_scalar() && a_ramp) { + return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + } + } + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return fcompute(BroadcastTo(a, lanes, is_scalable), + BroadcastTo(b, lanes, is_scalable)); + } + } +}; + +} // namespace tl +} // namespace tvm \ No newline at end of file diff --git a/tilelang/original/src/transform/common/thread_sync_types.h b/tilelang/original/src/transform/common/thread_sync_types.h new file mode 100644 index 0000000000000000000000000000000000000000..bbcf4c2b4af7c4a322d68b950c53302aa740eb35 --- /dev/null +++ b/tilelang/original/src/transform/common/thread_sync_types.h @@ -0,0 +1,51 @@ +/*! + * \file thread_sync_types.h + */ +#ifndef TVM_TL_THREAD_BOUND_KEY_H_ +#define TVM_TL_THREAD_BOUND_KEY_H_ + +#include +#include + +namespace tvm { +namespace tl { + +struct ThreadBoundKey { + int64_t tx_min, tx_max, ty_min, ty_max, tz_min, tz_max; + bool operator==(const ThreadBoundKey &other) const { + return tx_min == other.tx_min && tx_max == other.tx_max && + ty_min == other.ty_min && ty_max == other.ty_max && + tz_min == other.tz_min && tz_max == other.tz_max; + } +}; + +// There are 16 Named Barriers provided by Hardware starting in Hopper +// Their IDs are in the range 0-15 +// Number of threads syncing using the barrier must be a multiple of warp-size +// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads) +// may use it and conflict with other uses. +enum class ReservedNamedBarriers : uint8_t { + kSyncThreads = 0, + kReduce_0 = 1, + kReduce_1 = 2, + kFirstUsedBarrier = kReduce_1 + 1 +}; + +} // namespace tl +} // namespace tvm + +namespace std { +template <> struct hash { + size_t operator()(const tvm::tl::ThreadBoundKey &k) const { + size_t h = std::hash()(k.tx_min); + h = h * 31 + std::hash()(k.tx_max); + h = h * 31 + std::hash()(k.ty_min); + h = h * 31 + std::hash()(k.ty_max); + h = h * 31 + std::hash()(k.tz_min); + h = h * 31 + std::hash()(k.tz_max); + return h; + } +}; +} // namespace std + +#endif // TVM_TL_THREAD_BOUND_KEY_H_ diff --git a/tilelang/original/src/transform/common/union_find.h b/tilelang/original/src/transform/common/union_find.h new file mode 100644 index 0000000000000000000000000000000000000000..75192ad37fdbcefc57c845ae1e121bec0742b18f --- /dev/null +++ b/tilelang/original/src/transform/common/union_find.h @@ -0,0 +1,52 @@ +#ifndef TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_ +#define TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_ + +#include +#include + +namespace tvm { +namespace tl { + +template class UnionFind { +public: + void MakeSet(const T &x) { + if (parent_.find(x) == parent_.end()) { + parent_[x] = x; + rank_[x] = 0; + } + } + + T Find(const T &x) { + if (parent_[x] != x) { + parent_[x] = Find(parent_[x]); // Path compression + } + return parent_[x]; + } + + void Union(const T &x, const T &y) { + T x_root = Find(x); + T y_root = Find(y); + + if (x_root == y_root) + return; + + // Union by rank + if (rank_[x_root] < rank_[y_root]) { + parent_[x_root] = y_root; + } else if (rank_[x_root] > rank_[y_root]) { + parent_[y_root] = x_root; + } else { + parent_[y_root] = x_root; + rank_[x_root]++; + } + } + +private: + std::unordered_map parent_; + std::unordered_map rank_; +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_ diff --git a/tilelang/original/src/transform/config_index_bitwidth.cc 
b/tilelang/original/src/transform/config_index_bitwidth.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0a577555ebdc23e122a81aac9ba3f184ee7281b --- /dev/null +++ b/tilelang/original/src/transform/config_index_bitwidth.cc @@ -0,0 +1,193 @@ +#include "../op/builtin.h" +#include "arith/ir_mutator_with_analyzer.h" +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace arith; +class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter { +public: + using Parent = IndexDataTypeRewriter; + ConfigIndexBitwidthRewriter(int index_bitwidth) + : _index_bitwidth_(index_bitwidth) {} + + Stmt operator()(const Stmt &s) { return VisitStmt(s); } + +protected: + using Parent::VisitExpr_; + using Parent::VisitStmt_; + + PrimExpr VisitExpr_(const VarNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + DataType new_dtype = DataType::Int(64); + if (!var_remap_.count(op)) { + var_remap_[op] = Var(op->name_hint, new_dtype); + } + } + return Parent::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const IntImmNode *op) final { + if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) { + return IntImm(DataType::Int(_index_bitwidth_), op->value); + } + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const CastNode *op) final { + if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) { + PrimExpr value = VisitExpr(op->value); + return Cast(DataType::Int(_index_bitwidth_), value); + } + return Parent::VisitExpr_(op); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + // Force indices to be int64 + bool is_enabled = is_enabled_; + is_enabled_ = true; + auto node = Downcast(Parent::VisitStmt_(op)); + is_enabled_ = is_enabled; + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + // Force indices to be int64 + bool is_enabled = is_enabled_; + is_enabled_ = true; + auto node = Downcast(Parent::VisitExpr_(op)); + is_enabled_ = is_enabled; + return std::move(node); + } + + int _index_bitwidth_; +}; + +class IndexLegalizer : public IRMutatorWithAnalyzer { + +public: + static Stmt Rewrite(const Stmt &stmt) { + Analyzer ana; + auto pass = IndexLegalizer(&ana); + return pass.VisitStmt(stmt); + } + +private: + explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {} + + class Int64Promoter : public IndexDataTypeRewriter { + public: + using Parent = IndexDataTypeRewriter; + + PrimExpr VisitExpr_(const VarNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return cast(DataType::Int(64), tvm::ffi::GetRef(op)); + } + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const IntImmNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return IntImm(DataType::Int(64), op->value); + } + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const CastNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return cast(DataType::Int(64), op->value); + } + return tvm::ffi::GetRef(op); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + // Force indices to be int64 + auto node = Downcast(Parent::VisitStmt_(op)); + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(Parent::VisitExpr_(op)); + return std::move(node); + } + }; + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto buffer_store = + Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + auto indices = buffer_store->indices; + Array 
new_indices; + for (auto index : indices) { + if (index->dtype.is_int() && index->dtype.bits() < 64) { + auto int_bound = analyzer_->const_int_bound(index); + if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 || + int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { + Int64Promoter promoter; + index = promoter(index); + new_indices.push_back(index); + continue; + } + } + new_indices.push_back(index); + } + buffer_store.CopyOnWrite()->indices = new_indices; + return std::move(buffer_store); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto buffer_load = + Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + auto indices = buffer_load->indices; + Array new_indices; + for (auto index : indices) { + if (index->dtype.is_int() && index->dtype.bits() < 64) { + auto int_bound = analyzer_->const_int_bound(index); + if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 || + int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { + Int64Promoter promoter; + index = promoter(index); + new_indices.push_back(index); + continue; + } + } + new_indices.push_back(index); + } + buffer_load.CopyOnWrite()->indices = new_indices; + return std::move(buffer_load); + } +}; + +tvm::transform::Pass ConfigIndexBitwidth() { + using namespace tir::transform; + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + // Get pass config `tl.config_index_bitwidth` + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + Optional opt_config_index_bitwidth = + ctxt->GetConfig(kConfigIndexBitwidth, Optional()); + if (opt_config_index_bitwidth.defined()) { + int config_index_bitwidth = opt_config_index_bitwidth.value()->value; + n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(n->body); + } + // Legalize out-of-bound indices to be int64 + n->body = IndexLegalizer::Rewrite(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth", + ConfigIndexBitwidth); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/eliminate_storage_sync_for_mbarrier.cc b/tilelang/original/src/transform/eliminate_storage_sync_for_mbarrier.cc new file mode 100644 index 0000000000000000000000000000000000000000..504de732ca80b9a25fd8d46ce4f6c705260c039f --- /dev/null +++ b/tilelang/original/src/transform/eliminate_storage_sync_for_mbarrier.cc @@ -0,0 +1,125 @@ +/*! 
+ * \file eliminate_storage_sync_for_mbarrier.cc + */ +#include "../op/builtin.h" +#include "./storage_access.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +class Eliminator : public IRMutatorWithAnalyzer { +public: + static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) { + arith::Analyzer analyzer; + Eliminator transformer(&analyzer); + return transformer.VisitStmt(stmt); + } + + Eliminator(arith::Analyzer *analyzer) : IRMutatorWithAnalyzer(analyzer) { + im_mbarrier_for_ = false; + in_mbarrier_region_ = false; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "thread_extent") { + if (const auto *var = op->node.as()) { + if (var->name_hint == "threadIdx.x") { + thread_extent_ = op; + } + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + const CallNode *call = nullptr; + if (op->value->IsInstance()) { + call = op->value.as(); + if (call->op.same_as(builtin::tvm_storage_sync())) { + // Skip storage sync if we're in a region with mbarrier operations + // and we're not in a for loop with mbarrier operations + if (in_mbarrier_region_ || im_mbarrier_for_) { + return Stmt(); + } + } else if (call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_wait_barrier())) { + in_mbarrier_region_ = true; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { + bool old_in_mbarrier = in_mbarrier_region_; + Stmt then_case = VisitStmt(op->then_case); + + Stmt ret; + if (op->else_case.defined()) { + in_mbarrier_region_ = old_in_mbarrier; + Stmt else_case = VisitStmt(op->else_case.value()); + in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; + ret = IfThenElse(VisitExpr(op->condition), then_case, else_case); + } else { + in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; + ret = IfThenElse(VisitExpr(op->condition), then_case, Stmt()); + } + return ret; + } + + Stmt VisitStmt_(const ForNode *op) final { + PostOrderVisit(tvm::ffi::GetRef(op), [&](const ObjectRef &node) { + if (const auto *call = node.as()) { + if (call->op.same_as(create_list_of_mbarrier()) || + call->op.same_as(mbarrier_wait_parity()) || + call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_cp_async_barrier())) { + im_mbarrier_for_ = true; + } + } + }); + auto stmt = IRMutatorWithAnalyzer::VisitStmt_(op); + im_mbarrier_for_ = false; + return stmt; + } + +private: + bool im_mbarrier_for_; + bool in_mbarrier_region_; + const AttrStmtNode *thread_extent_{nullptr}; +}; +using namespace tir::transform; + +namespace transform { + +tvm::transform::Pass EliminateStorageSyncForMBarrier() { + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + n->body = Eliminator::Substitute(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.EliminateStorageSyncForMBarrier", + {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier", + EliminateStorageSyncForMBarrier); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/flatten_buffer.cc 
b/tilelang/original/src/transform/flatten_buffer.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b68d337374971545ac0a52da3c677c40ff421f8 --- /dev/null +++ b/tilelang/original/src/transform/flatten_buffer.cc @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file flatten_buffer.cc + */ + +#include "arith/ir_mutator_with_analyzer.h" +#include "tir/transforms/ir_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported + * dimension for the TIR not contains opaque block. + */ +class BufferFlattener : public arith::IRMutatorWithAnalyzer { +public: + static PrimFunc Flatten(PrimFunc func) { + arith::Analyzer ana; + auto pass = BufferFlattener(&ana); + if (auto init_map = + func->attrs.GetAttr>(tl::attr::kLocalVarInit)) { + pass.local_var_init_map_ = init_map.value(); + } + auto writer = func.CopyOnWrite(); + pass.MarkBufferMapShapes(func); + writer->body = pass.VisitStmt(func->body); + // The buffers in func->buffer_map are deliberately left + // unflattened, as they are used for validation of user-provided + // arguments. The flattened buffers used in the updated + // function body alias the argument buffers. 
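+    // For example, a [16, 32] buffer allocated inside the body is rewritten
+    // to a flat [512] buffer, and an access A[i, j] is lowered to
+    // A[i * 32 + j].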
+ return func; + } + +private: + using IRMutatorWithAnalyzer::VisitExpr; + using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt; + using IRMutatorWithAnalyzer::VisitStmt_; + + class Int64Promoter : public tir::IndexDataTypeRewriter { + public: + using Parent = IndexDataTypeRewriter; + + PrimExpr VisitExpr_(const VarNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return cast(DataType::Int(64), tvm::ffi::GetRef(op)); + } + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const IntImmNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return IntImm(DataType::Int(64), op->value); + } + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const CastNode *op) final { + if (op->dtype.is_int() && op->dtype.bits() < 64) { + return cast(DataType::Int(64), op->value); + } + return tvm::ffi::GetRef(op); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + // Force indices to be int64 + auto node = Downcast(Parent::VisitStmt_(op)); + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(Parent::VisitExpr_(op)); + return std::move(node); + } + }; + + explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {} + + Stmt VisitStmt_(const BlockNode *op) final { + ICHECK_EQ(op->match_buffers.size(), 0) + << "Unexpected MatchBufferRegion found during " + "tir.transform.FlattenBuffer. " + << "All MatchBufferRegion should be removed in " + "tir.transform.LowerMatchBuffer."; + + Block block = tvm::ffi::GetRef(op); + + Array alloc_buffers = op->alloc_buffers; + alloc_buffers.MutateByApply( + [this](const Buffer &buf) { return GetFlattenedBuffer(buf); }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } + + Array reads = op->reads; + reads.MutateByApply([this](BufferRegion region) { + return MutateBufferRegion(std::move(region)); + }); + if (!reads.same_as(op->reads)) { + block.CopyOnWrite()->reads = reads; + } + + Array writes = op->writes; + writes.MutateByApply([this](BufferRegion region) { + return MutateBufferRegion(std::move(region)); + }); + if (!writes.same_as(op->writes)) { + block.CopyOnWrite()->writes = writes; + } + + return StmtExprMutator::VisitStmt_(block.get()); + } + + Stmt VisitStmt_(const AllocateNode *op) final { + // Determine the flattened extents first, before stripping of + // DeclBuffer. + auto new_extents = [&]() -> Array { + if (op->extents.size() == 1) { + // No flattening required for buffers that are already flat + return op->extents; + } + + if (auto *decl_buffer = op->body.as()) { + // N-d buffer, use the DeclBuffer inside to determine how it + // should be flattened. 
+ auto &buffer = decl_buffer->buffer; + bool matching_buffer = [&]() { + if (!decl_buffer->buffer->data.same_as(op->buffer_var)) { + return false; + } + if (op->dtype != buffer->dtype) { + return false; + } + if (op->extents.size() != buffer->shape.size()) { + return false; + } + ExprDeepEqual expr_equal; + for (size_t i = 0; i < op->extents.size(); i++) { + if (!expr_equal(op->extents[i], buffer->shape[i])) { + return false; + } + } + return true; + }(); + + if (matching_buffer) { + Buffer flattened = GetFlattenedBuffer(buffer); + return flattened->shape; + } else { + ICHECK(decl_buffer->buffer->axis_separators.empty()) + << "DeclBuffer node doesn't match Allocate extents, but also " + "shouldn't be " + "flattened to 1-d physical memory"; + } + } + + // Fallback, this is an allocation without a matching DeclBuffer + PrimExpr flat_extent = 1; + for (const auto &dim : op->extents) { + flat_extent *= dim; + } + return {flat_extent}; + }(); + + Allocate alloc = Downcast(StmtExprMutator::VisitStmt_(op)); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (alloc->dtype == DataType::Bool()) { + alloc.CopyOnWrite()->dtype = DataType::Int(8); + } + + if (!new_extents.same_as(alloc->extents)) { + alloc.CopyOnWrite()->extents = new_extents; + } + if (!local_var_init_map_.empty()) { + auto init_it = local_var_init_map_.find(alloc->buffer_var); + if (init_it != local_var_init_map_.end()) { + const PrimExpr &init = (*init_it).second; + alloc.CopyOnWrite()->annotations.Set(tl::attr::kLocalVarInit, init); + } + } + + return std::move(alloc); + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + // TODO(rfc-70): Update the DeclBuffer node instead of + // stripping it out. Stripping it out in the current + // implementation as not all lowering passes support + // DeclBuffer. + return VisitStmt(op->body); + } + + Buffer GetFlattenedBuffer(const Buffer &buf) { + auto it = buffer_remap_.find(buf); + if (it != buffer_remap_.end()) { + return it->second; + } + auto flattened = buf.GetFlattenedBuffer(); + auto writer = flattened.CopyOnWrite(); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (flattened->dtype == DataType::Bool()) { + writer->dtype = DataType::Int(8); + } + // canonicalize shape + for (size_t i = 0; i < flattened->shape.size(); ++i) { + writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + } + + buffer_remap_[buf] = flattened; + return flattened; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + bool store_returns_bool = (op->value.dtype() == DataType::Bool()); + store = VisitBufferAccess(store); + + // Handle casts from the value's dtype to the dtype of the + // backing array. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (store_returns_bool) { + ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + auto writer = store.CopyOnWrite(); + writer->value = tvm::cast(DataType::Int(8), store->value); + return std::move(store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + bool load_returns_bool = (op->dtype == DataType::Bool()); + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + load = VisitBufferAccess(load); + // Handle casts from dtype of the backing array to value's dtype. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. 
+ if (load_returns_bool && !under_address_of) { + ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + load.CopyOnWrite()->dtype = DataType::Int(8); + return tvm::cast(DataType::Bool(), load); + } else { + return std::move(load); + } + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::address_of())) { + under_address_of = true; + auto result = StmtExprMutator::VisitExpr_(op); + under_address_of = false; + return result; + } + return StmtExprMutator::VisitExpr_(op); + } + + Array GetSimplifiedElemOffset(const Buffer &buffer, + const Array &indices) { + auto flattened_indices = buffer->ElemOffset(indices); + Array safe_indices; + for (auto index : flattened_indices) { + auto int_bound = analyzer_->const_int_bound(index); + DataType dtype = index->dtype; + if (dtype.is_int() && dtype.bits() < 64) { + int64_t max_value = int_bound->max_value; + int64_t min_value = int_bound->min_value; + const int64_t type_max = (1LL << (dtype.bits() - 1)); + const int64_t type_min = -(1LL << (dtype.bits() - 1)); + + if (max_value >= (type_max - 1) || min_value < type_min) { + Int64Promoter promoter; + for (auto &index : flattened_indices) { + safe_indices.push_back(promoter(index)); + } + } else { + safe_indices.push_back(index); + } + } else { + safe_indices.push_back(index); + } + } + return this->IterMapSimplifyWithContext(safe_indices, false); + } + + template Node VisitBufferAccess(Node node) { + ICHECK(node->buffer.defined()); + auto flattened_indices = + GetSimplifiedElemOffset(node->buffer, node->indices); + Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); + + auto writer = node.CopyOnWrite(); + writer->buffer = flattened_buffer; + writer->indices = flattened_indices; + return node; + } + + BufferRegion MutateBufferRegion(BufferRegion region) { + Buffer orig_buf = region->buffer; + Buffer flattened_buf = GetFlattenedBuffer(orig_buf); + if (flattened_buf.same_as(orig_buf)) { + return region; + } + + Array min_values; + Array max_values; + for (const auto &range : region->region) { + min_values.push_back(range->min); + max_values.push_back(range->min + range->extent - 1); + } + + Array flattened_min = + GetSimplifiedElemOffset(orig_buf, min_values); + Array flattened_max = + GetSimplifiedElemOffset(orig_buf, max_values); + + Array flattened_ranges; + ICHECK_EQ(flattened_min.size(), flattened_max.size()); + for (size_t i = 0; i < flattened_min.size(); i++) { + flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); + } + + return BufferRegion(flattened_buf, flattened_ranges); + } + + /*! \brief Whether the current buffer is under address_of */ + bool under_address_of = false; + /*! \brief Map of buffers being remapped. */ + std::unordered_map + buffer_remap_; + + /*! \brief The updated external buffer map. */ + Map updated_extern_buffer_map_; + + /*! \brief Local var initializers preserved from block annotations. 
*/ + Map local_var_init_map_; +}; + +PrimFunc FlattenBufferRewriter(PrimFunc f) { + return BufferFlattener::Flatten(std::move(f)); +} + +using namespace tir::transform; +tvm::transform::Pass FlattenBuffer() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return FlattenBufferRewriter(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/frontend_legalize.cc b/tilelang/original/src/transform/frontend_legalize.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffb4b1a5342ff2efb1d28691713c422d0d7ec68d --- /dev/null +++ b/tilelang/original/src/transform/frontend_legalize.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file frontend_legalize.cc + * \brief Legalize the program from frontend + */ + +#include +#include +#include +#include + +#include "arith/ir_mutator_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class LetInliner : public arith::IRMutatorWithAnalyzer { +public: + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + LetInliner substituter(&analyzer); + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = substituter.VisitStmt(f->body); + return f; + } + +private: + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + Stmt VisitStmt_(const ForNode *node) final { + if (node->kind == ForKind::kParallel) { + parallel_for_scope_++; + } + auto n = StmtExprMutator::VisitStmt_(node); + if (node->kind == ForKind::kParallel) { + parallel_for_scope_--; + } + return n; + } + + PrimExpr VisitExpr_(const VarNode *node) final { + if (let_bindings_.count(node)) { + return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]); + } else { + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } + } + + Stmt VisitStmt_(const LetStmtNode *node) final { + let_bindings_[node->var.get()] = node->value; + return arith::IRMutatorWithAnalyzer::VisitStmt(node->body); + } + + PrimExpr VisitExpr_(const LetNode *node) final { + let_bindings_[node->var.get()] = node->value; + return arith::IRMutatorWithAnalyzer::VisitExpr(node->body); + } + + int parallel_for_scope_ = 0; + std::unordered_map let_bindings_; +}; + +using namespace tir::transform; + +Pass LetInline() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return LetInliner::Substitute(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + 
refl::GlobalDef().def("tl.transform.LetInline", LetInline); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/hoist_nonrestrict_params.cc b/tilelang/original/src/transform/hoist_nonrestrict_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..90db747e88a6cad9c73107be8b2f7168d89d2e50 --- /dev/null +++ b/tilelang/original/src/transform/hoist_nonrestrict_params.cc @@ -0,0 +1,133 @@ +/* + * Hoist tl.non_restrict_params block annotation(s) to PrimFunc attribute. + * + * Previously, we only looked at the root block. This version recursively + * scans all blocks, unions any tl.non_restrict_params entries it finds, + * merges with any existing PrimFunc-level attribute, then writes the + * deduplicated result back to the PrimFunc attrs. This makes annotation + * placement within the function body flexible for frontends. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { +using namespace tvm::tir; + +class NonRestrictCollector : public StmtVisitor { +public: + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + Array Result() const { + Array out; + out.reserve(collected_.size()); + for (const Var &v : collected_) + out.push_back(v); + return out; + } + +private: + static std::string NormalizeName(const std::string &s) { + if (s.size() >= 8 && s.rfind("_handle") == s.size() - 7) { + return s.substr(0, s.size() - 7); + } + return s; + } + + void MaybeInsert(const Var &v) { + if (!v.defined()) + return; + const VarNode *p = v.get(); + if (seen_ptr_.count(p)) + return; + // Also dedup by normalized name to be robust w.r.t recreated Vars + std::string norm = NormalizeName(v->name_hint); + if (seen_name_.count(norm)) + return; + seen_ptr_.insert(p); + seen_name_.insert(std::move(norm)); + collected_.push_back(v); + } + + void VisitStmt_(const BlockNode *op) final { + auto it = op->annotations.find(attr::kNonRestrictParams); + if (it != op->annotations.end()) { + if (const auto *arr = (*it).second.as()) { + // Downcast directly to Array for convenience + Array vars = tvm::Downcast>((*it).second); + for (const Var &v : vars) { + MaybeInsert(v); + } + } + } + // Recurse into child statements + StmtVisitor::VisitStmt_(op); + } + + std::vector collected_; + std::unordered_set seen_ptr_; + std::unordered_set seen_name_; +}; + +static PrimFunc HoistNonRestrictParams(PrimFunc f) { + if (!f.defined()) + return f; + + NonRestrictCollector collector; + collector.Collect(f->body); + Array from_blocks = collector.Result(); + + // Merge with any existing PrimFunc-level attribute if present + if (auto opt_existing = f->GetAttr>(attr::kNonRestrictParams)) { + for (const Var &v : opt_existing.value()) { + // Reuse the collector's dedup logic by temporarily constructing a new + // collector Alternatively, do a small inline dedup mirroring MaybeInsert + // Here we inline a simplified pointer-based dedup plus name-based + // fallback + bool exists = false; + for (const Var &cur : from_blocks) { + if (cur.get() == v.get() || cur->name_hint == v->name_hint) { + exists = true; + break; + } + } + if (!exists) + from_blocks.push_back(v); + } + } + + if (from_blocks.empty()) + return f; + + return WithAttr(std::move(f), attr::kNonRestrictParams, + std::move(from_blocks)); +} + +namespace transform { + +tvm::transform::Pass HoistNonRestrictParams() { + auto pass_func = [](PrimFunc f, const IRModule &, + const tvm::transform::PassContext &) { + return 
tvm::tl::HoistNonRestrictParams(std::move(f)); + }; + return tvm::tir::transform::CreatePrimFuncPass( + pass_func, 0, "tl.HoistNonRestrictParams", {}); +} + +} // namespace transform + +} // namespace tl +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.HoistNonRestrictParams", + tvm::tl::transform::HoistNonRestrictParams); +} diff --git a/tilelang/original/src/transform/if_stmt_binding.cc b/tilelang/original/src/transform/if_stmt_binding.cc new file mode 100644 index 0000000000000000000000000000000000000000..5da796c9de292903ff6e1a6b7b19edf1f34c4067 --- /dev/null +++ b/tilelang/original/src/transform/if_stmt_binding.cc @@ -0,0 +1,90 @@ +/*! + * \file if_stmt_binding.cc + * \brief Bind the If Stmt to each Stmt in SeqStmt + */ + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class IfStmtBindingRewriter : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f) { + auto rewriter = IfStmtBindingRewriter(); + f.CopyOnWrite()->body = rewriter(f->body); + return f; + } + +private: + IfStmtBindingRewriter() = default; + + Stmt VisitStmt_(const IfThenElseNode *op) final { + auto condition = op->condition; + auto then_case = VisitStmt(op->then_case); + Optional else_case = op->else_case; + if (else_case.defined()) { + return tvm::ffi::GetRef(op); + } + ICHECK(then_case.defined()) << "then_case must be defined"; + ICHECK(!else_case.defined()) << "else_case must be undefined"; + + auto bind_if_stmt = [](const Optional &body, + const PrimExpr &condition) -> Stmt { + if (body.defined()) { + auto stmt = body.value(); + if (auto seq_stmt = stmt.as()) { + Array seq_; + for (auto s : seq_stmt->seq) { + seq_.push_back(IfThenElse(condition, s, Stmt())); + } + return SeqStmt(std::move(seq_)); + } else { + return IfThenElse(condition, stmt, Stmt()); + } + } else { + return Stmt(); + } + }; + + Array new_seq; + + if (then_case.defined()) { + new_seq.push_back(bind_if_stmt(then_case, condition)); + } + return new_seq.size() == 1 ? new_seq[0] : SeqStmt(std::move(new_seq)); + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + Array seq; + for (auto stmt : op->seq) { + seq.push_back(VisitStmt(stmt)); + } + return SeqStmt(std::move(seq)); + } +}; + +using namespace tir::transform; +tvm::transform::Pass IfStmtBinding() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return IfStmtBindingRewriter::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/inject_assumes.cc b/tilelang/original/src/transform/inject_assumes.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a5fc62ca2bb313d027440be26c89b86ee5d1f6f --- /dev/null +++ b/tilelang/original/src/transform/inject_assumes.cc @@ -0,0 +1,186 @@ +/*! + * \file inject_assumes.cc + * \brief Inject assumes on buffer's shape boundary check. Also convert + * existing assumes to AttrNodes. 
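+ *
+ * Illustrative sketch (hypothetical buffer, for exposition only): for a
+ * buffer A with symbolic shape (m, n), the pass wraps the enclosing body as
+ *     AttrStmt(m > 0, tilelang_assume, "Buffer shape should be ...",
+ *       AttrStmt(n > 0, tilelang_assume, "...", body))
+ * with one attribute per distinct symbolic shape expression, and it also
+ * rewrites explicit T.assume(cond) evaluates inside a SeqStmt into the same
+ * AttrStmt form scoping over the statements that follow the assume.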
+ */ + +#include "common/assume.h" +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/optional.h" +#include "tvm/ir/expr.h" +#include "tvm/ir/transform.h" +#include "tvm/node/structural_hash.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/op.h" +#include "tvm/tir/stmt.h" +#include "tvm/tir/stmt_functor.h" +#include "tvm/tir/transform.h" + +#include + +namespace tvm::tl { +using namespace tir; + +class AssumeInjector : public tvm::tir::StmtExprMutator { + using Base = tvm::tir::StmtExprMutator; + +public: + AssumeInjector(PrimFunc f) : f(f) {} + static PrimFunc Substitute(PrimFunc f) { + auto injector = AssumeInjector(f); + f.CopyOnWrite()->body = injector(f->body); + return f; + } + +private: + struct AssumeCreator { + struct Item { + PrimExpr expr; + std::vector buffers; + }; + + tvm::StructuralHash sh; + tvm::StructuralEqual se; + // grouped by expr, since the amount of variadic shape symbols is usually + // much smaller than buffer + std::vector items; + // hash => index in items + std::unordered_map> buckets; + void addExpr(PrimExpr e, Buffer buffer) { + size_t h = sh(e); + auto &bucket = buckets[h]; + auto it = std::find_if(bucket.begin(), bucket.end(), [&](size_t y) { + return se(e, items[y].expr, true); + }); + if (it == bucket.end()) { + auto index = items.size(); + items.push_back({e, {buffer}}); + bucket.push_back(index); + } else { + items[*it].buffers.push_back(buffer); + } + } + + void addBuffer(Buffer buf) { + for (auto shape : buf->shape) { + if (shape->IsInstance()) + continue; + addExpr(shape, buf); + } + } + + Stmt build(Stmt body) { + auto analyzer = arith::Analyzer{}; + for (const auto &e : items) { + auto simplified = + analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype))); + std::stringstream ss; + ss << "Buffer shape should be greater than 0: shape `" << e.expr + << "` from buffer "; + for (size_t i = 0; i < e.buffers.size(); i++) { + if (i) + ss << ", "; + ss << "`" << e.buffers[i]->name << "`"; + } + body = AttrStmt(simplified, tir::attr::tilelang_assume, + StringImm(ss.str()), body); + } + return body; + } + }; + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto body = VisitStmt(op->body); + AssumeCreator c; + c.addBuffer(op->buffer); + return DeclBuffer(op->buffer, c.build(body), op->span); + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + struct AssumeGroup { + std::optional e; + std::vector stmts; + }; + std::vector groups = {AssumeGroup{std::nullopt, {}}}; + for (size_t i = 0; i < op->seq.size(); i++) { + auto stmt = VisitStmt(op->seq[i]); + // Convert assume in evaluate form to assume attribute. + // By default, we have the following IR: + // T.assume(cond1) + // Stmt1 + // Stmt2 + // T.assume(cond2) + // This SeqStmt will be converted to: + // With(attr::tilelang_assume, cond1) { + // Stmt1 + // Stmt2 + // } + // With(attr::tilelang_assume, cond2) { + // ... + // } + if (auto e = GetAssumeExprInEvaluateForm(stmt)) { + groups.push_back(AssumeGroup{*e, {}}); + } else { + groups.back().stmts.push_back(stmt); + } + } + for (size_t i = groups.size(); i--;) { + auto &g = groups[i]; + if (g.e) { + Stmt body = g.stmts.size() == 1 ? g.stmts[0] : SeqStmt(g.stmts); + std::stringstream ss; + ss << "Assume: " << *(g.e); + AttrStmt attr = AttrStmt(*g.e, tir::attr::tilelang_assume, + StringImm(ss.str()), body); + groups[i - 1].stmts.push_back(attr); + } else { + ICHECK(i == 0) << "only the first group can have no assume"; + } + } + return groups[0].stmts.size() == 1 ? 
groups[0].stmts[0] + : SeqStmt(groups[0].stmts); + // return SeqStmt(groups[0].stmts); + } + + Stmt VisitStmt_(const BlockNode *op) final { + auto body = VisitStmt(op->body); + AssumeCreator c; + + // NOTE(chaofan): We only inject assumes from function arguments in the + // root block. + if (op->name_hint == "root") { + for (auto item : f->buffer_map) { + c.addBuffer(item.second); + } + } + for (auto item : op->alloc_buffers) { + c.addBuffer(item); + } + for (auto item : op->match_buffers) { + c.addBuffer(item->buffer); + } + + return Block(op->iter_vars, op->reads, op->writes, op->name_hint, + c.build(body), op->init, op->alloc_buffers, op->match_buffers, + op->annotations, op->span); + } + + PrimFunc f; +}; + +using namespace tir::transform; + +tvm::transform::Pass InjectAssumes() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return AssumeInjector::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes); +} + +} // namespace tvm::tl diff --git a/tilelang/original/src/transform/inject_fence_proxy.cc b/tilelang/original/src/transform/inject_fence_proxy.cc new file mode 100644 index 0000000000000000000000000000000000000000..6152789a2e7c2e78a62f757d8c2efa2e5b50f8de --- /dev/null +++ b/tilelang/original/src/transform/inject_fence_proxy.cc @@ -0,0 +1,329 @@ +/*! + * \file inject_fence_proxy.cc + * \brief Inject proxy fences between generic and async proxies (sm90+) + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using tvm::transform::PassContext; + +// Tracks what kind of proxy activity a statement performs so we can decide when +// to inject fences while traversing the IR. +enum class ProxyKind : uint8_t { + kUnknown, + kGeneric, + kAsync, + kMixed, + kNeutral, // Acts as a barrier and resets proxy state (e.g., fence + // instructions) +}; + +namespace { + +inline bool IsAsync(ProxyKind kind) { return kind == ProxyKind::kAsync; } +inline bool IsGeneric(ProxyKind kind) { return kind == ProxyKind::kGeneric; } + +// Merge two proxy kinds to represent the aggregate behaviour of a compound +// node. +inline ProxyKind CombineProxy(ProxyKind a, ProxyKind b) { + if (a == ProxyKind::kUnknown) + return b; + if (b == ProxyKind::kUnknown) + return a; + if (a == ProxyKind::kNeutral) + return b; + if (b == ProxyKind::kNeutral) + return a; + if (a == b) + return a; + return ProxyKind::kMixed; +} + +// We only need a fence when transitioning from generic operations to async +// ones. +inline bool NeedsFence(ProxyKind prev, ProxyKind curr) { + if (prev == ProxyKind::kUnknown || curr == ProxyKind::kUnknown) + return false; + if (prev == ProxyKind::kNeutral || curr == ProxyKind::kNeutral) + return false; + if (prev == ProxyKind::kMixed || curr == ProxyKind::kMixed) + return false; + return IsGeneric(prev) && IsAsync(curr); +} + +inline bool IsFenceCall(const CallNode *call) { + return call && call->op.same_as(fence_proxy_async()); +} + +// Identify async intrinsics emitted by TileLang or TVM that require a fence +// when they follow generic proxies. 
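+// Illustrative sketch (hypothetical sequence, for exposition only): given a
+// SeqStmt such as
+//     smem[i] = value;    // BufferStore -> generic proxy
+//     tma_load(...);      // async proxy intrinsic
+// the injector below emits Evaluate(fence_proxy_async()) between the two
+// statements, since the proxy kind transitions from generic to async.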
+bool IsAsyncIntrinsic(const CallNode *call) { + if (call == nullptr) { + return false; + } + + // TileLang async intrinsics + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_store()) || call->op.same_as(tma_store_arrive()) || + call->op.same_as(tma_store_wait()) || + call->op.same_as(ptx_cp_async_barrier_noinc()) || + call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs())) { + return true; + } + + // PTX async copy intrinsics + if (call->op.same_as(builtin::ptx_cp_async()) || + call->op.same_as(builtin::ptx_cp_async_barrier()) || + call->op.same_as(builtin::ptx_cp_async_bulk())) { + return true; + } + + // wgmma async intrinsics + if (call->op.same_as(tl_gemm()) || call->op.same_as(tl_gemm_sp())) { + return true; + } + + return false; +} + +// Known ops that must be treated as generic proxies (e.g. ldmatrix/stmatrix). +bool IsKnownGeneric(const CallNode *call) { + if (call == nullptr) { + return false; + } + return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) || + call->op.same_as(initialize_wgmma_descriptor()) || + call->op.same_as(initialize_tcgen05_descriptor()); +} + +ProxyKind ProxyFromAttrValue(const ObjectRef &value) { + if (const auto *str = value.as()) { + if (str->value == "async") { + return ProxyKind::kAsync; + } + if (str->value == "generic") { + return ProxyKind::kGeneric; + } + if (str->value == "neutral") { + return ProxyKind::kNeutral; + } + } + return ProxyKind::kUnknown; +} + +// TMA stores must be followed by the arrive/wait pair. We rewrite them as part +// of the pass to guarantee the proper synchronization semantics. +class TMAStoreSyncInjector : public StmtExprMutator { +public: + static PrimFunc Apply(PrimFunc f) { + if (!f->body.defined()) { + return f; + } + auto injector = TMAStoreSyncInjector(); + f.CopyOnWrite()->body = injector(f->body); + return f; + } + +private: + Stmt operator()(const Stmt &stmt) { return StmtExprMutator::VisitStmt(stmt); } + + Stmt VisitStmt_(const EvaluateNode *op) final { + Stmt mutated = StmtExprMutator::VisitStmt_(op); + const auto *node = mutated.as(); + if (const auto *call = node->value.as()) { + if (call->op.same_as(tma_store())) { + Array seq; + seq.push_back(mutated); + seq.push_back( + Evaluate(Call(DataType::Handle(), tma_store_arrive(), {}))); + seq.push_back(Evaluate(Call(DataType::Handle(), tma_store_wait(), {}))); + return SeqStmt(std::move(seq)); + } + } + return mutated; + } +}; + +// Main pass: track the proxy state while walking the IR and inject fences when +// switching from generic to async proxies. +class ProxyFenceInjector : public StmtMutator { +public: + static PrimFunc Apply(PrimFunc f) { + if (!f->body.defined()) { + return f; + } + ProxyFenceInjector injector; + f.CopyOnWrite()->body = injector.VisitStmt(f->body); + return f; + } + +private: + Stmt VisitStmt_(const SeqStmtNode *op) final { + Array seq; + seq.reserve(op->seq.size()); + + ProxyKind sequence_kind = ProxyKind::kUnknown; + ProxyKind prev_kind = ProxyKind::kUnknown; + + for (const Stmt &stmt : op->seq) { + Stmt new_stmt = VisitStmt(stmt); + ProxyKind current_kind = GetProxyKind(new_stmt); + + if (!seq.empty() && NeedsFence(prev_kind, current_kind)) { + Stmt fence = MakeFenceStmt(); + seq.push_back(fence); + prev_kind = GetProxyKind(fence); + } + + seq.push_back(new_stmt); + sequence_kind = CombineProxy(sequence_kind, current_kind); + prev_kind = current_kind; + } + + Stmt result = seq.size() == 1 ? 
seq[0] : SeqStmt(std::move(seq)); + SetProxyKind(result, sequence_kind); + return result; + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *evaluate = stmt.as(); + ProxyKind kind = ProxyKind::kGeneric; + + if (const auto *call = evaluate->value.as()) { + if (IsFenceCall(call)) { + kind = ProxyKind::kNeutral; + } else if (IsAsyncIntrinsic(call)) { + kind = ProxyKind::kAsync; + } else if (IsKnownGeneric(call)) { + kind = ProxyKind::kGeneric; + } else { + // We can now treat extern as Generic, since gemm and gemm_sp are never + // represented as call_extern nodes. They are call_intrin nodes and will + // be handled by IsAsyncIntrinsic above. + kind = ProxyKind::kGeneric; + } + } + + SetProxyKind(stmt, kind); + return stmt; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + SetProxyKind(stmt, ProxyKind::kGeneric); + return stmt; + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *node = stmt.as(); + ProxyKind kind = GetProxyKind(node->then_case); + if (node->else_case.defined()) { + kind = CombineProxy(kind, GetProxyKind(node->else_case.value())); + } + SetProxyKind(stmt, kind); + return stmt; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *node = stmt.as(); + ProxyKind body_kind = GetProxyKind(node->body); + SetProxyKind(stmt, body_kind); + return stmt; + } + + Stmt VisitStmt_(const BlockRealizeNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *node = stmt.as(); + SetProxyKind(stmt, GetProxyKind(node->block)); + return stmt; + } + + Stmt VisitStmt_(const BlockNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *node = stmt.as(); + ProxyKind kind = ProxyKind::kUnknown; + if (node->init.defined()) { + kind = CombineProxy(kind, GetProxyKind(node->init.value())); + } + kind = CombineProxy(kind, GetProxyKind(node->body)); + SetProxyKind(stmt, kind); + return stmt; + } + + Stmt VisitStmt_(const ForNode *op) final { return VisitSingleBody(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { return VisitSingleBody(op); } + Stmt VisitStmt_(const AssertStmtNode *op) final { + return VisitSingleBody(op); + } + Stmt VisitStmt_(const WhileNode *op) final { return VisitSingleBody(op); } + + template Stmt VisitSingleBody(const NodeType *op) { + Stmt stmt = StmtMutator::VisitStmt_(op); + const auto *node = stmt.as(); + ProxyKind body_kind = GetProxyKind(node->body); + SetProxyKind(stmt, body_kind); + return stmt; + } + + void SetProxyKind(const Stmt &stmt, ProxyKind kind) { + proxy_map_[stmt.get()] = kind; + } + + ProxyKind GetProxyKind(const Stmt &stmt) const { + if (!stmt.defined()) { + return ProxyKind::kUnknown; + } + auto it = proxy_map_.find(stmt.get()); + if (it == proxy_map_.end()) { + return ProxyKind::kUnknown; + } + return it->second; + } + + Stmt MakeFenceStmt() { + Stmt fence = Evaluate(Call(DataType::Handle(), fence_proxy_async(), {})); + SetProxyKind(fence, ProxyKind::kNeutral); + return fence; + } + + std::unordered_map proxy_map_; +}; + +} // namespace + +tvm::transform::Pass InjectFenceProxy() { + auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { + f = TMAStoreSyncInjector::Apply(f); + f = ProxyFenceInjector::Apply(f); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", + {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/inject_pipeline.cc b/tilelang/original/src/transform/inject_pipeline.cc new file mode 100644 index 0000000000000000000000000000000000000000..79e78add94bad4c233d6a5de19fba939ed0065cb --- /dev/null +++ b/tilelang/original/src/transform/inject_pipeline.cc @@ -0,0 +1,1170 @@ +/*! + * \file inject_software_pipeline.cc + * \brief Transform annotated loops into pipelined one that parallelize + * producers and consumers + */ +#include +#include +#include + +#include +#include +#include + +#include "support/utils.h" +#include "tir/schedule/utils.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { +using namespace tir; +using namespace ffi; +namespace software_pipeline { + +struct LetWrapper { + Var var; + PrimExpr value; +}; + +/*! + * \brief Create a block and infer the access region with the given body. + * + * The result is a opaque block that doesn't contain any block iter vars. In + * case the body is a block realize without predicate, it is unnecessary to + * create a new block, the block of the block realize will be returned. + * + * \param body The body of the block. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \return The result block. + */ +Block MakeBlock(const Stmt &body, + const Map &buffer_data_to_buffer) { + if (const BlockRealizeNode *block_realize = body.as()) { + if (is_one(block_realize->predicate)) { + // no need to create a new block + return block_realize->block; + } + } + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ body); + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer); + BlockNode *n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + return block; +} + +/*! Structure that represents the provided annotation per block or loop. */ +struct PipelineAnnotation { + int stage; + int order; + bool async; + // Index of the statement in the original loop body order (SeqStmt order) + int original_idx = -1; +}; + +using PipelineInfo = std::unordered_map; + +struct BufferAccessInfo { + int def = -1; // the defining stage of the buffer + int use = -1; // the last using stage of the buffer +}; + +/*! + * \brief Rewriter for the body of the software pipeline. This pass inserts + * `floormod` to indices of the remapped buffer to select the version + * corresponding to the pipeline stage. + */ +class PipelineBodyRewriter : public StmtExprMutator { +public: + /*! + * \brief Constructor of PipelineBodyRewriter. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param buffer_remap The map from original buffer to the buffer with updated + * shape for multi-versioning in the software pipeline. \param pipeline_loop + * The original loop to be software pipelined. \param access_all_versions + * Whether all versions the buffers in the software pipeline are accessed. + * This will be used to update block access region. In the prologue and + * epilogue of a two-stage software pipeline, only one version of these + * buffers are accessed. 
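+ *
+ * Illustrative sketch (hypothetical shapes, for exposition only): a
+ * double-buffered allocation B of shape (128,) is remapped to shape
+ * (2, 128), and an access B[i] inside the pipelined loop becomes
+ * B[floormod(loop_var - loop_min, 2), i], selecting the version owned by
+ * the current pipeline stage.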
+ */ + PipelineBodyRewriter(const Map &buffer_data_to_buffer, + const Map &buffer_remap, + For pipeline_loop, bool access_all_versions) + : buffer_data_to_buffer_(buffer_data_to_buffer), + buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)), + access_all_versions_(access_all_versions) {} + +private: + BufferRegion + RewritePipelineBufferRegion(const BufferRegion &buffer_region) const { + auto it = buffer_remap_.find(buffer_region->buffer); + if (it != buffer_remap_.end()) { + Region new_region = buffer_region->region; + const Buffer &new_buffer = (*it).second; + // For pipeline buffers, relax the access region of the first dimension to + // full extent if access_all_versions == true + Range accessed_version = + access_all_versions_ + ? Range::FromMinExtent(0, new_buffer->shape[0]) + : Range::FromMinExtent( + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), + new_buffer->shape[0]), + Integer(1)); + new_region.insert(new_region.begin(), accessed_version); + return BufferRegion(new_buffer, new_region); + } + return buffer_region; + } + + PrimExpr RewriteBufferAccess(const Call &call, + const std::vector &arg_indices) { + auto product = [](const Array &input) { + return foldl( + [](PrimExpr a, PrimExpr b, Span span) { + return mul(std::move(a), std::move(b), std::move(span)); + }, + make_const(DataType::Int(32), 1), input); + }; + Array new_args = call->args; + for (int i : arg_indices) { + const Buffer &buffer = + buffer_data_to_buffer_.at(Downcast(call->args[i])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + const Buffer &new_buffer = (*it).second; + const PrimExpr &old_index = call->args[i + 1]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = product(buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = + old_index + + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; + new_args.Set(i + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } + + Stmt VisitStmt_(const BlockNode *op) final { + for (const Buffer &alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + BlockNode *n = block.CopyOnWrite(); + n->reads.MutateByApply([this](const BufferRegion &buffer_region) { + return RewritePipelineBufferRegion(buffer_region); + }); + n->writes.MutateByApply([this](const BufferRegion &buffer_region) { + return RewritePipelineBufferRegion(buffer_region); + }); + for (const Buffer &alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(alloc_buffer->data); + } + return block; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return store; + } + const Buffer &new_buffer = (*it).second; + auto *n = store.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = floormod( + (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return load; + } + const Buffer &new_buffer = (*it).second; + auto *n = load.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = floormod( + 
(pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return load; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(builtin::tvm_access_ptr())) { + return RewriteBufferAccess(call, {1}); + } + return call; + } + + Map buffer_data_to_buffer_; + Map buffer_remap_; + For pipeline_loop_; + bool access_all_versions_; +}; + +/*! + * \brief Rewriter for the software pipeline that rewrite a loop into a + * pipelined one. + */ +class PipelineRewriter : public StmtExprMutator { +public: + PipelineRewriter(Map buffer_data_to_buffer, + const Array &pipeline_allocs, + const For &pipeline_loop, const PipelineInfo &pipeline_info, + const std::vector &loop_var_let_wrappers) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), + pipeline_info_(pipeline_info), + loop_var_let_wrappers_(loop_var_let_wrappers) {} + + Stmt BuildPipeline() { + // Step 1: Analyze accesses to the buffers in the pipeline and compute the + // number of versions need to maintain for each buffer. + std::unordered_map + infos = GetBufferAccessInfo(); + for (const Buffer &buffer : pipeline_allocs_) { + int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); + if (num_versions > 1) { + buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); + } + } + ordered_stmts_.resize(pipeline_info_.size()); + for (const auto &[block, anno] : pipeline_info_) { + ordered_stmts_.Set(anno.order, block); + } + + for (const Block &block : ordered_stmts_) { + int stage = pipeline_info_[block].stage; + if (pipeline_info_[block].async) { + auto &state = async_states[stage]; + state.producer_head = pipeline_loop_->min - 1; + for (auto write_region : block->writes) { + auto buffer = write_region->buffer; + state.dst_buffers.insert(buffer.get()); + if (buffer_remap_.count(buffer)) + state.dst_buffers.insert(buffer_remap_[buffer].get()); + } + } + } + std::unordered_set consumed; + for (const Block &block : ordered_stmts_) { + int stage = pipeline_info_[block].stage; + if (pipeline_info_[block].async) { + auto &state = async_states[stage]; + if (state.commit_groups.empty() || consumed.count(stage)) { + state.commit_groups.push_back({}); + } + state.commit_groups.back().push_back(pipeline_info_[block].order); + consumed.erase(stage); + for (auto write_region : block->writes) { + auto buffer = buffer_remap_.count(write_region->buffer) + ? buffer_remap_[write_region->buffer] + : write_region->buffer; + state.buffer_to_commit_group_[buffer.get()] = + state.commit_groups.size() - 1; + } + } + + for (auto read_region : block->reads) { + for (const auto &[producer_stage_id, producer_state] : async_states) { + if (producer_stage_id <= stage && + producer_state.writes(read_region->buffer)) { + consumed.insert(producer_stage_id); + } + } + } + } + + // Step 2: Emit the pipeline prologue, body and epilogue. 
+ Stmt prologue = + EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, + true, false); + Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false, + false, false); + + Stmt epilogue = + EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, + true, true, true); + SeqStmt stmt = SeqStmt({prologue, body, epilogue}); + + // Step 3: Make a new block that contains new buffer allocations after + // pipeline rewriting. + Array alloc_buffers; + for (const auto &alloc : pipeline_allocs_) { + alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); + buffer_data_to_buffer_.erase(alloc->data); + } + Block block = MakeBlock(stmt, buffer_data_to_buffer_); + block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); + return BlockRealize({}, Bool(true), block); + } + +private: + /*! + * \brief Analyze accesses to the buffers in the software pipeline. + * + * This method check the 'define' and 'use' stage of the buffers in the + * software pipeline, which can be used to compute the number of versions + * needed to maintain after rewriting. + */ + std::unordered_map + GetBufferAccessInfo() { + std::unordered_map + infos; + for (const auto &pair : pipeline_info_) { + const Block &block = pair.first; + int stage = pair.second.stage; + max_stage_ = std::max(max_stage_, stage); + + for (const BufferRegion &write : block->writes) { + if (!infos.count(write->buffer)) { + infos.emplace(write->buffer, BufferAccessInfo{}); + } + auto &info = infos.at(write->buffer); + if (info.def == -1) { + info.def = stage; + } else { + info.def = std::min(info.def, stage); + } + } + + for (const BufferRegion &read : block->reads) { + if (!infos.count(read->buffer)) { + infos.emplace(read->buffer, BufferAccessInfo{}); + } + auto &info = infos.at(read->buffer); + info.use = std::max(info.use, stage); + } + } + return infos; + } + + /*! + * \brief Check whether two regions have intersections. + * \param region1 The first region. + * \param region2 The second region. + * \return Whether region1 and region2 have intersections. + */ + bool MayConflict(const Region ®ion1, const Region ®ion2) { + ICHECK(region1.size() == region2.size()); + for (size_t i = 0; i < region1.size(); i++) { + Range dim1 = region1[i]; + Range dim2 = region2[i]; + auto int_set1 = arith::IntSet::FromRange(dim1); + auto int_set2 = arith::IntSet::FromRange(dim2); + if (arith::Intersect({int_set1, int_set2}).IsNothing()) { + return false; + } + } + return true; + } + + /*! + * \brief Compute the number of versions need to maintain for buffer accessed + * in the software pipeline. + * + * This method applies liveness analysis to the target buffer to compute the + * number of versions need to maintain during the software pipeline. + * Annotation `attr::double_buffer_scope` is handled here which provides a way + * to override the result of the analysis. Additional double buffering in the + * software pipeline can be useful to eliminate synchronizations in GPU + * devices. + * + * \param buffer The target buffer + * \param buffer_info The access information of the target buffer. + * \return The number of versions required for the target buffer. + */ + int ComputeBufferVersions(const Buffer &buffer, + const BufferAccessInfo &buffer_info) { + if (buffer_info.def == -1) { + // Keep the original number of versions as buffers defined outside the + // software pipeline should not be mutated. 
+ return 1; + } + + // `use - def + 1` is a upper bound of the needed versions + // We optimize a few case where the number of versions can be smaller than + // the upper bound + int num_versions = buffer_info.use - buffer_info.def + 1; + if (num_versions >= 2) { + // A special case when `use - def + 1 == 2`. Double buffering is only + // needed in this case when these exists a reader block_i and a writer + // block_j such that order(block_i) < order(block_j) and stage(block_i) < + // stage(block_j) and the access regions of block_i and block_j overlap. + bool need_multi_version = false; + for (const auto &pair1 : pipeline_info_) { + const Block &writer_block = pair1.first; + const auto &writer_info = pair1.second; + + auto it1 = std::find_if(writer_block->writes.begin(), + writer_block->writes.end(), + [&](const BufferRegion &buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it1 == writer_block->writes.end()) { + continue; + } + + for (const auto &pair2 : pipeline_info_) { + const Block &reader_block = pair2.first; + const auto &reader_info = pair2.second; + auto it2 = std::find_if( + reader_block->reads.begin(), reader_block->reads.end(), + [&](const BufferRegion &buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it2 == reader_block->reads.end()) { + continue; + } + if (writer_info.order < reader_info.order && + writer_info.stage < reader_info.stage && + MayConflict((*it1)->region, (*it2)->region)) { + need_multi_version = true; + break; + } + } + } + if (!need_multi_version) { + num_versions--; + } + } + return num_versions; + } + + /*! + * \brief Rewrite buffer allocation to keep multiple versions of original + * buffer for pipelined accesses. \param buffer The buffer to be resized. + * \param num_versions The number of versions to keep. + * \return The resized buffer. + */ + Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { + ObjectPtr new_buffer = + tvm::ffi::make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); + if (!new_buffer->strides.empty()) { + ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); + } + return Buffer(new_buffer); + } + + // Per-stage states that need to be tracked across pipeline prologue, body, + // and epilogue. + struct AsyncStateGlobal { + // Buffers that this stage asynchronously writes. + std::unordered_set dst_buffers; + // An imaginary index that the latest async operation associated with this + // stage has written into. Only valid if all associated predicates are true, + // so that we can count the number of async invocations exactly. When it is + // valid, it is the "sum of extents of loops that have been executed" - 1, + // e.g. for epilogue it is prologue extent + body extent - 1. This is only + // needed to compute wait count for epilogue without async producers. + PrimExpr producer_head; + std::vector> commit_groups; + std::unordered_map buffer_to_commit_group_; + bool writes(const Buffer &buf) const { + return dst_buffers.count(buf.get()) > 0; + } + }; + + // Per-stage states that are local to each of pipeline prologue, body, and + // epilogue. + struct AsyncStateLocal { + struct PendingWait { + // The index into a list of blocks, where async_wait_queue should be + // attached at the beginning. 
+ int insert_before; + // in_flight_count would be a more precise name, but the implementation + // uses wait_count for brevity. + PrimExpr wait_count{nullptr}; + + bool valid() const { return wait_count.defined(); } + }; + + std::vector pending_waits; + + // A symbolic expression representing the index the latest async operation + // associated with this stage has written into, at the "current" iteration. + Optional producer_head; + // the commit block's predicate + PrimExpr commit_predicate{nullptr}; + }; + + /*! Structure holding intermediate information for pipeline loop rewriting. */ + struct RewrittenBlockInfo { + int stage; + int order; + PrimExpr start; + PrimExpr end; + PrimExpr predicate; + Block block; + PrimExpr access_index; + bool is_async; + }; + + void PopulateWaitCounts(const std::vector &new_blocks, + std::map *async_states_local, + bool is_epilogue = false) { + // Precompute which orders are present in this emit, and their access_index + std::unordered_map order_to_access_index; + std::unordered_set present_orders; + for (const auto &nb : new_blocks) { + order_to_access_index[nb.order] = nb.access_index; + present_orders.insert(nb.order); + } + for (size_t i = 0; i < new_blocks.size(); ++i) { + // 1. Find the unique async producer stage + int producer_stage_idx = -1; + for (const auto &read_region : new_blocks[i].block->reads) { + for (const auto &[stage, state] : async_states) { + if (stage <= new_blocks[i].stage && + state.writes(read_region->buffer)) { + // Currently only a single async stage dependency is supported + ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage) + << "A dependency on multiple async stages is not supported"; + producer_stage_idx = stage; + } + } + } + if (producer_stage_idx == -1) { + // This block does not depend on any async producer + continue; + } + const auto &state = async_states[producer_stage_idx]; + + auto &dep_local_state = (*async_states_local)[producer_stage_idx]; + + // 2. Use buffer_to_commit_group_ to find all actually dependent commit + // groups + std::unordered_set dependent_groups; + for (const auto &read_region : new_blocks[i].block->reads) { + auto it = state.buffer_to_commit_group_.find(read_region->buffer.get()); + if (it != state.buffer_to_commit_group_.end()) { + dependent_groups.insert(it->second); + } + } + + // If there is no dependent commit group, no wait needs to be inserted + if (dependent_groups.empty()) { + continue; + } + + // 3. Compute wait = max_g max(0, t_consumer - committed_before[g]) + PrimExpr t_consumer = new_blocks[i].access_index; + PrimExpr wait_expr = make_zero(t_consumer.dtype()); + + PrimExpr current_head = dep_local_state.producer_head.defined() + ? 
dep_local_state.producer_head.value() + : state.producer_head; + int consumer_order = new_blocks[i].order; + + for (int g : dependent_groups) { + const auto &group = state.commit_groups[g]; + if (group.empty()) + continue; + int commit_order = group.back(); + bool commit_present = present_orders.count(commit_order) > 0; + + PrimExpr committed_before; + if (commit_present && commit_order <= consumer_order) { + // Commit point is in this iteration and earlier than the current + // consumer; this iteration's head is visible + auto commit_predicate = dep_local_state.commit_predicate; + if (analyzer_.CanProve(!commit_predicate, + arith::ProofStrength::kSymbolicBound)) { + // it means the commit block is not executed in this iteration + committed_before = new_blocks[i].start - 1; + } else if (is_epilogue) { + committed_before = new_blocks[i].start - 1; + } else { + committed_before = order_to_access_index.at(commit_order); + } + } else { + // Commit point is later than the current consumer or not in this + // iteration; only the previous iteration's head is visible + if (dep_local_state.producer_head.defined()) { + auto commit_predicate = dep_local_state.commit_predicate; + if (analyzer_.CanProve(!commit_predicate, + arith::ProofStrength::kSymbolicBound)) { + committed_before = new_blocks[i].start - 1; + } else if (is_epilogue) { + committed_before = new_blocks[i].start - 1; + } else { + committed_before = current_head - 1; + } + } + } + + wait_expr = analyzer_.Simplify(committed_before - t_consumer); + } + + wait_expr = analyzer_.Simplify(wait_expr); + dep_local_state.pending_waits.push_back({static_cast(i), wait_expr}); + } + } + + // Given pipelined blocks and async-related information, generate final loop + // statements with async scopes (if any). + Array CompletePipelineLoopStatements( + const std::vector &blocks, + const std::map &async_states_local) const { + std::vector new_blocks = blocks; + for (const auto &[stage_id, state] : async_states_local) { + for (const auto &pw : state.pending_waits) { + auto &block = new_blocks[pw.insert_before].block; + BlockNode *n = block.CopyOnWrite(); + auto zero = make_zero(DataType::Int(32)); + n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, + AttrStmt(zero, tir::attr::async_wait_inflight_count, + pw.wait_count, n->body)); + } + } + + // mark the last async stmt as commit + std::unordered_set commit_group_indices; + for (const auto &[stage_id, state] : async_states) { + for (size_t i = 0; i < state.commit_groups.size(); ++i) { + commit_group_indices.insert(state.commit_groups[i].back()); + } + } + + Array stmts; + + for (size_t i = 0; i < new_blocks.size(); i++) { + Block block = new_blocks[i].block; + if (commit_group_indices.count(new_blocks[i].order)) { + auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, + new_blocks[i].stage, block->body); + block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); + } + stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block)); + } + + return stmts; + } + + /*! + * \brief Emit the pipeline loop in the given range. + * \param start The start of the range + * \param end The end of the range + * \param unroll_loop Whether the loop should be unrolled. + * \return The result loop. 
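+ *
+ * For reference, BuildPipeline above emits three instances of this loop:
+ * the prologue over [min, min + max_stage), the steady-state body over
+ * [min + max_stage, min + extent), and the epilogue over
+ * [min + extent, min + extent + max_stage); prologue and epilogue are
+ * emitted unrolled with bound checks, the body is not.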
+ */ + Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop, + bool need_bound_check, bool is_epilogue = false) { + PrimExpr new_loop_var; + PrimExpr extent = end - start; + auto make_nop = []() { + return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); + }; + + bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); + if (is_unit_loop) { + new_loop_var = start; // use constants as the loop var for unit loops + } else { + new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); + // Bind the iteration domain [start, end) to strengthen analyzer facts. + analyzer_.Bind(Downcast(new_loop_var), + Range::FromMinExtent(start, end - start)); + } + // Keep the bound constraints active for all analysis below. + // Only meaningful when the loop var is symbolic (non-unit loop). + std::unique_ptr> ctx_lb_guard; + std::unique_ptr> ctx_ub_guard; + if (!is_unit_loop) { + Var loop_iter = Downcast(new_loop_var); + ctx_lb_guard.reset( + new With(&analyzer_, loop_iter >= start)); + ctx_ub_guard.reset( + new With(&analyzer_, loop_iter < end)); + } + + std::vector new_blocks; + + // Async related + std::map async_states_local; + + for (const Block &block : ordered_stmts_) { + int stage = pipeline_info_.at(block).stage; + int order = pipeline_info_.at(block).order; + + PrimExpr inbound = Bool(true); + PrimExpr skewed_loop_var = new_loop_var - stage; + if (need_bound_check) + inbound = And( + pipeline_loop_->min <= skewed_loop_var, + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent)); + + Block new_block = Downcast( + PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, + pipeline_loop_, max_stage_ != 1)(block)); + + PrimExpr delta = start - pipeline_loop_->min; + // This variable corresponds to + // - "producer_head" if this stage is an async producer + // - "consumer_head" if this stage reads from asynchronously written + // buffers. + PrimExpr normalized_access_index = + is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; + + normalized_access_index = analyzer_.Simplify(normalized_access_index); + + // Adjust the block predicate and the body according to the final loop + // bound + // [pipeline_loop_->min, extent). + if (!is_unit_loop) { + Var loop_iter = Downcast(new_loop_var); + inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); + } + new_block = Downcast(Substitute( + new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); + + // If there were Let-wrappers outside the original pipeline body that + // depended on the pipeline loop var, push them into each rewritten + // block with the correct per-block substitution. 
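+      // Illustrative sketch (hypothetical binding, for exposition only): a
+      // wrapper `k = loop_var * 2` found outside the original pipeline body
+      // is re-emitted inside every rewritten block as
+      //     LetStmt(k, normalized_access_index * 2, block_body)
+      // so each prologue/body/epilogue copy sees its own iteration index.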
+ if (!loop_var_let_wrappers_.empty()) { + BlockNode *n = new_block.CopyOnWrite(); + Stmt inner = n->body; + for (const auto &lw : loop_var_let_wrappers_) { + PrimExpr substituted = Substitute( + lw.value, {{pipeline_loop_->loop_var, normalized_access_index}}); + inner = LetStmt(lw.var, substituted, inner); + } + n->body = inner; + } + + if (pipeline_info_[block].async) { + auto &local_state = async_states_local[stage]; + local_state.producer_head = normalized_access_index; + local_state.commit_predicate = inbound; + BlockNode *n = new_block.CopyOnWrite(); + n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, + 1, n->body); + } + + new_blocks.push_back({stage, order, start, end, inbound, new_block, + normalized_access_index, + pipeline_info_[block].async}); + } + + PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue); + + auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local); + + Stmt new_loop{nullptr}; + + if (stmts.empty()) { + return make_nop(); + } + + if (stmts.size() == 1) { + new_loop = stmts[0]; + } else { + new_loop = SeqStmt(stmts); + } + + if (!is_unit_loop) { + Map preserved_annotations; + for (const auto &kv : pipeline_loop_->annotations) { + const String &key = kv.first; + if (kv.first != tir::attr::software_pipeline_stage && + kv.first != tir::attr::software_pipeline_order && + kv.first != tir::attr::software_pipeline_async_stages) { + preserved_annotations.Set(key, kv.second); + } + } + new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, + std::move(new_loop), std::nullopt, preserved_annotations); + } + // Update producer heads in the global async states. + for (const auto &[stage_id, state] : async_states_local) { + async_states[stage_id].producer_head += extent; + } + + return BlockRealize({}, Bool(true), + MakeBlock(new_loop, buffer_data_to_buffer_)); + } + + arith::Analyzer analyzer_; + Map buffer_data_to_buffer_; + Array pipeline_allocs_; + For pipeline_loop_; + PipelineInfo pipeline_info_; + int max_stage_ = -1; + Map buffer_remap_; + Array ordered_stmts_; + std::map async_states; + std::vector loop_var_let_wrappers_; +}; + +/*! + * \brief Build the dependency graph among a array of blocks. + * \param[in] blocks The array of blocks. + * \param[out] dep_src2dst Optional, a map to store dependency edges from the + * source to the destination. \param[out] dep_dst2src Optional, a map to store + * dependency edges from the destination to the source. 
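+ *
+ * Illustrative sketch (hypothetical blocks, for exposition only): if block A
+ * writes buffer X and a later block B reads X, an edge A -> B is recorded in
+ * dep_src2dst (and B -> A in dep_dst2src); only read-after-write edges keyed
+ * on the buffer's data var are tracked here.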
+ */ +void BuildDependencyGraph(const Array &blocks, + std::unordered_map, ObjectPtrHash, + ObjectPtrEqual> *dep_src2dst, + std::unordered_map, ObjectPtrHash, + ObjectPtrEqual> *dep_dst2src) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_writers; + + for (const Block &block : blocks) { + for (const BufferRegion &read : block->reads) { + auto it = buffer_writers.find(read->buffer->data); + if (it != buffer_writers.end()) { + for (const Block &writer : it->second) { + if (dep_src2dst != nullptr) { + (*dep_src2dst)[writer].push_back(block); + } + if (dep_dst2src != nullptr) { + (*dep_dst2src)[block].push_back(writer); + } + } + } + } + for (const BufferRegion &write : block->writes) { + buffer_writers[write->buffer->data].push_back(block); + } + } +} + +class PipelineInjector : private StmtExprMutator { +public: + static Stmt Inject(const PrimFunc &func) { + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + PipelineInjector injector(global_symbol); + for (const auto &kv : func->buffer_map) { + const Buffer &buffer = kv.second; + injector.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + return injector(func->body); + } + +private: + explicit PipelineInjector(Optional global_symbol) + : global_symbol_(std::move(global_symbol)) {} + + /*! + * \brief Check the pipeline satisfies the following conditions: + * 1. No conflicting order: The order of each statement should be unique. + * 2. Reordering of statements doesn't break buffer access dependencies. + * Specifically, for dependency (e.g. read-after-write) from statement A to + * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) == + * stage(B) and order(A) < order(B) + */ + void ValidatePipelineBody(const PipelineInfo &pipeline_info, + const Array &original_order) { + std::unordered_set used_orders; + std::unordered_map stage_max_order; + std::unordered_map order_to_block; + std::unordered_map block_to_stage; + for (const Block &block : original_order) { + const auto &stmt_info = pipeline_info.at(block); + int order = stmt_info.order; + CHECK(!used_orders.count(order)) + << "ValueError: Two statements in the software pipeline cannot have " + "the same order"; + used_orders.insert(order); + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + dep_src2dst; + BuildDependencyGraph(original_order, &dep_src2dst, nullptr); + + for (const auto &pair : dep_src2dst) { + const Block &src = pair.first; + const auto &src_info = pipeline_info.at(src); + const Array &dsts = pair.second; + for (const Block &dst : dsts) { + const auto &dst_info = pipeline_info.at(dst); + CHECK_LE(src_info.stage, dst_info.stage) + << "ValueError: statement " << dst << " in stage " << dst_info.stage + << " cannot depends on statement " << src << " in a later stage " + << src_info.stage; + if (src_info.stage == dst_info.stage) { + CHECK_LT(src_info.order, dst_info.order) + << "ValueError: two statements with buffer " + "access dependency in the same stage of the " + "software pipeline cannot be reordered"; + } + } + } + } + + Stmt VisitStmt_(const ForNode *op) final { + // Step 1: Recursively rewrite the children first. + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (!HasPipelineAnnotation(op)) { + return for_node; + } + // Step 2: Find the body and buffer allocations of the pipeline. The body + // can be direct child of the for-loop. If the for-loop has BlockRealize as + // its child, the pipeline body will be the child of the block. 
+ Stmt pipeline_body_root{nullptr}; + bool pipeline_body_from_block = false; + Array pipeline_allocs; + if (const auto *realize = for_node->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body_root = block->body; + pipeline_allocs = block->alloc_buffers; + pipeline_body_from_block = true; + } else { + pipeline_body_root = for_node->body; + } + + const SeqStmtNode *pipeline_body_seq = nullptr; + std::vector> rewrap_fns; + std::vector loop_var_let_wrappers; + auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) { + Any node = attr->node; + String attr_key = attr->attr_key; + PrimExpr value = attr->value; + Span span = attr->span; + rewrap_fns.emplace_back( + [node = std::move(node), attr_key = std::move(attr_key), + value = std::move(value), span](Stmt body) -> Stmt { + return AttrStmt(node, attr_key, value, body, span); + }); + }; + { + Stmt current = pipeline_body_root; + while (true) { + if (const auto *seq_stmt = current.as()) { + pipeline_body_seq = seq_stmt; + break; + } + if (const auto *if_then_else = current.as()) { + ICHECK(!if_then_else->else_case.defined()) + << "InjectSoftwarePipeline: Can't handle the body of the loop " + "because the IfThenElse node has an else branch"; + PrimExpr condition = if_then_else->condition; + Span span = if_then_else->span; + rewrap_fns.emplace_back( + [condition = std::move(condition), span](Stmt body) -> Stmt { + return IfThenElse(condition, body, Stmt(), span); + }); + current = if_then_else->then_case; + continue; + } + if (const auto *let_stmt = current.as()) { + // If this Let value uses the pipeline loop var, record it and push + // inside each rewritten block later so the loop var can be + // substituted with the correct per-iteration index. Otherwise, keep + // it as a normal wrapper. + bool uses_loop_var = UsesVar( + let_stmt->value, + [v = op->loop_var.get()](const VarNode *vn) { return vn == v; }); + if (uses_loop_var) { + loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value}); + } else { + Var var = let_stmt->var; + PrimExpr value = let_stmt->value; + Span span = let_stmt->span; + rewrap_fns.emplace_back([var = std::move(var), + value = std::move(value), + span](Stmt body) -> Stmt { + return LetStmt(var, value, body, span); + }); + } + current = let_stmt->body; + continue; + } + if (const auto *attr = current.as()) { + append_attr_wrapper(attr); + current = attr->body; + continue; + } + LOG(FATAL) << "ValueError: The body of the software pipeline should be " + << "SeqStmt, got " << current->GetTypeKey(); + } + } + ICHECK(pipeline_body_seq != nullptr); + + // Step 3: Blockize the components of the pipeline. Each child of the + // pipelined loop will be converted into a block. 
+ PipelineInfo pipeline_info; + Array original_order; // pipeline body blocks in the original order + + auto f_add_child = [&](const Stmt &child) { + original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); + }; + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const Stmt &child = pipeline_body_seq->seq[i]; + const auto *nested_block_realize = child.as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + const Block &nested_pipeline_block = nested_block_realize->block; + ICHECK(nested_pipeline_block->match_buffers + .empty()); // match_buffer should have been lowered + for (const auto &buffer : nested_pipeline_block->alloc_buffers) { + pipeline_allocs.push_back(buffer); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + } + f_add_child(child); + } + + auto pipeline_stages = Downcast>( + op->annotations.at(tir::attr::software_pipeline_stage)); + auto pipeline_orders = Downcast>( + op->annotations.at(tir::attr::software_pipeline_order)); + CHECK_EQ(pipeline_stages.size(), original_order.size()) + << "PrimFunc " << global_symbol_ << " has original order " + << original_order.Map( + [](const auto &block) { return block->name_hint; }) + << ", but pipeline annotation is " << pipeline_stages + << " with different size"; + CHECK_EQ(pipeline_orders.size(), original_order.size()) + << "PrimFunc " << global_symbol_ << " has original order " + << original_order.Map( + [](const auto &block) { return block->name_hint; }) + << ", but pipeline annotation is " << pipeline_orders + << " with different size"; + + std::unordered_set pipeline_async_stages; + if (auto annot = + op->annotations.Get(tir::attr::software_pipeline_async_stages)) { + for (auto s : Downcast>(annot.value())) { + pipeline_async_stages.insert(s->value); + } + } + + for (size_t i = 0; i < pipeline_stages.size(); i++) { + int stage = static_cast(pipeline_stages[i]->value); + bool is_async = + pipeline_async_stages.find(stage) != pipeline_async_stages.end(); + PipelineAnnotation stage_order{ + stage, + /*order=*/static_cast(pipeline_orders[i]->value), is_async, + /*original_idx=*/static_cast(i)}; + pipeline_info.emplace(original_order[i], stage_order); + } + + ValidatePipelineBody(pipeline_info, original_order); + + // Step 4: Rewrite the pipeline body. 
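+    // Hedged sketch of the rewritten result for the 2-stage example above; the
+    // exact shape (multi-buffering, async commits, predicates) is decided by
+    // PipelineRewriter, defined earlier in this file:
+    //
+    //   load(0)                     # prologue: warm up stage 0
+    //   for ko in range(K - 1):
+    //       load(ko + 1)            # stage 0 runs one iteration ahead
+    //       compute(ko)             # stage 1 consumes the previous iteration
+    //   compute(K - 1)              # epilogue: drain the last compute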
+ Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, + tvm::ffi::GetRef(op), pipeline_info, + loop_var_let_wrappers) + .BuildPipeline(); + auto apply_wrappers = [&](Stmt stmt) { + for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { + stmt = (*it)(stmt); + } + return stmt; + }; + if (!rewrap_fns.empty()) { + if (pipeline_body_from_block) { + BlockRealize pipeline_realize = Downcast(pipeline); + Block pipeline_block = pipeline_realize->block; + { + BlockNode *block_node = pipeline_block.CopyOnWrite(); + block_node->body = apply_wrappers(block_node->body); + } + pipeline = BlockRealize(pipeline_realize->iter_values, + pipeline_realize->predicate, pipeline_block, + pipeline_realize->span); + } else { + pipeline = apply_wrappers(pipeline); + } + } + + if (const auto *realize = op->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + } + return pipeline; + } + + Stmt VisitStmt_(const BlockNode *op) final { + for (const auto &buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + BlockNode *n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + + for (const auto &buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + return block; + } + + bool HasPipelineAnnotation(const ForNode *op) const { + auto it1 = op->annotations.find(tir::attr::software_pipeline_stage); + auto it2 = op->annotations.find(tir::attr::software_pipeline_order); + bool has_stage = it1 != op->annotations.end(); + bool has_order = it2 != op->annotations.end(); + if (has_stage && has_order) { + return true; + } + if (has_stage) { + LOG(FATAL) + << "ValueError: Stage of the software pipeline is not defined."; + } + if (has_order) { + LOG(FATAL) + << "ValueError: Order of the software pipeline is not defined."; + } + return false; + } + + Map buffer_data_to_buffer_; + Optional global_symbol_; +}; +} // namespace software_pipeline + +/*! + * \brief Transform annotated loops into pipelined one that parallelize + * producers and consumers. \return The IR transform pass. + */ +tir::transform::Pass InjectSoftwarePipeline() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *fptr = f.CopyOnWrite(); + fptr->body = software_pipeline::PipelineInjector::Inject(f); + fptr->body = ConvertSSA(std::move(fptr->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline", + InjectSoftwarePipeline); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/inject_ptx_async_copy.cc b/tilelang/original/src/transform/inject_ptx_async_copy.cc new file mode 100644 index 0000000000000000000000000000000000000000..1fadefbf4ffd8b01d0eb9ed15c0eb1417146932a --- /dev/null +++ b/tilelang/original/src/transform/inject_ptx_async_copy.cc @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Replace copy from global to shared with async copy + * \file inject_ptx_async_copy.cc + */ +#include +#include +#include +#include +#include +#include +#include + +#include "storage_access.h" +#include "tir/ir/buffer_common.h" +#include "tvm/tir/stmt.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class PTXAsyncCopyInjector : public StmtMutator { +public: + Stmt VisitStmt_(const AttrStmtNode *attr) { + if (attr->attr_key == tir::attr::async_scope) { + ICHECK(in_async == false) << "Nested async scopes not supported"; + in_async = true; + auto body = this->VisitStmt(attr->body); + in_async = false; + return body; + } + return StmtMutator::VisitStmt_(attr); + } + + Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store, + bool predicated = false, + const PrimExpr &predicate_value = PrimExpr()) { + if (load->buffer.scope() == "global") { + ICHECK(load->indices.size() == 1 && store->indices.size() == 1); + ICHECK(load->indices[0]->dtype.lanes() == + store->indices[0]->dtype.lanes()) + << load->indices[0] << " vs. " << store->indices[0] << " with lanes " + << load->indices[0]->dtype.lanes() << " vs. " + << store->indices[0]->dtype.lanes(); + + const int indices_lanes = load->indices[0]->dtype.lanes(); + const int bytes = indices_lanes * load->buffer->dtype.bytes(); + + if (bytes == 4 || bytes == 8 || bytes == 16) { + auto dst_elem_type = + GetPointerType(store->buffer->data->type_annotation); + auto src_elem_type = + GetPointerType(load->buffer->data->type_annotation); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "Both store and load buffer should have a pointer type " + "annotation."; + + int index_factor = 1; + if (dst_elem_type.value() != src_elem_type.value()) { + // The only case where src and dst have different dtypes is when the + // dst shared memory is a byte buffer generated by merging dynamic + // shared memory. + ICHECK(store->buffer.scope() == "shared.dyn" || + store->buffer.scope() == "shared"); + ICHECK(dst_elem_type.value() == DataType::UInt(8)); + // BufferStore/Load have the "pointer reinterpret" semantics according + // to their "value" dtype. Their "indices" are supposed to be applied + // after such pointer cast, for example: + // ((*float16)(byte_buffer))[buffer->indices] = fp16_value; To replace + // BufferStore/Load with cp.async, we need to multiply the store index + // by the byte size of the "value" dtype, to get the correct offset + // into the byte buffer. 
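+          // Worked example (illustrative only): storing float16 values loaded
+          // from global memory into a merged "shared.dyn" byte buffer. A store
+          // written as ((float16*)byte_buf)[idx] = value touches byte offset
+          // idx * 2, so index_factor becomes 2 and the cp.async destination
+          // offset below is emitted as mul(dst_offset, 2).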
+ index_factor = src_elem_type->bytes(); + } + + if (indices_lanes == 1) { + auto src_offset = load->indices[0]; + auto dst_offset = store->indices[0]; + Array args = { + store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; + // use arguments size to indicate whether or not to use predicated + // cp.async + if (predicated) { + args.push_back(predicate_value); + } + return Evaluate(Call(store->buffer->dtype, + tvm::tir::builtin::ptx_cp_async(), args)); + } + + // Predicated load don't support vectorized indexing. + if (!predicated) { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by + // merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) + + // x8(17408))] = A_global[ramp(...),1, 8)] + auto *add = store->indices[0].as(); + if (!add->a->IsInstance()) + return PrimExpr(); + if (!add->b->IsInstance()) + return PrimExpr(); + return tir::Add(add->a.as()->base, + add->b.as()->value); + } + return PrimExpr(); + }(); + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate(Call( + store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)})); + } + } else { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by + // merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) + + // x8(17408))] = A_global[ramp(...),1, 8)] + auto *add = store->indices[0].as(); + if (!add->a->IsInstance()) + return PrimExpr(); + if (!add->b->IsInstance()) + return PrimExpr(); + return tir::Add(add->a.as()->base, + add->b.as()->value); + } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate(Call( + store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes), + predicate_value})); + } + } + } + } + return StmtMutator::VisitStmt_(store); + } + + Stmt VisitStmt_(const BufferStoreNode *store) { + bool is_shared = (store->buffer.scope() == "shared" || + store->buffer.scope() == "shared.dyn"); + if (in_async && is_shared) { + if (auto *load = store->value.as()) { + return InjectPTX(load, store); + } else if (auto *call = store->value.as()) { + // tir.if_then_else is a call to tir::builtin::if_then_else() + if (call->op.same_as(builtin::if_then_else()) && + call->args.size() == 3) { + if (auto *load = call->args[1].as()) { + // Only default value of 0 is supported since 0 is the default value + // used by cp.async ptx. @see section 9.7.8.22.3. 
of + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations + bool else_value_is_zero = false; + if (auto *b = call->args[2].as()) { + if (auto *f = b->value.as()) { + else_value_is_zero = f->value == 0.0f; + } else if (auto *i = b->value.as()) { + else_value_is_zero = i->value == 0; + } + } + if (auto *f = call->args[2].as()) { + else_value_is_zero = f->value == 0.0f; + } else if (auto *i = call->args[2].as()) { + else_value_is_zero = i->value == 0; + } + if (else_value_is_zero) { + return InjectPTX(load, store, true, call->args[0]); + } + } + } + } + } + return StmtMutator::VisitStmt_(store); + } + +private: + bool in_async{false}; +}; + +using namespace tir::transform; + +tvm::transform::Pass InjectPTXAsyncCopy() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + n->body = PTXAsyncCopyInjector()(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/inject_tma_barrier.cc b/tilelang/original/src/transform/inject_tma_barrier.cc new file mode 100644 index 0000000000000000000000000000000000000000..93beb15d4ff6968a74fab14b72582f7a0c6c96c5 --- /dev/null +++ b/tilelang/original/src/transform/inject_tma_barrier.cc @@ -0,0 +1,604 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tma_barrier_rewriter.cc + * \brief Rewrite TMA barriers for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "./common/attr.h" +#include "./common/collector.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace tir::transform; +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +class TmaTraitsCollector : public StmtExprVisitor { +public: + TmaTraitsCollector() { Initialize(); } + + void Initialize() { + bulk_copy_bytes = 0; + loop_extents = 1; + } + + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } + +private: + void VisitExpr_(const CallNode *call) final { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + auto arg0 = call->args[0].as(); + if (call->op.same_as(tma_load()) && arg0 && + !arg0.value()->op.same_as(create_tma_descriptor())) { + // 1D TMA load has tvm_access_ptr of shared tensor in its args[0] + bulk_copy_bytes = call->args[3] * loop_extents; + } else { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + int type_bytes = access_ptr->args[0]->dtype.bytes(); + bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; + } + } + StmtExprVisitor::VisitExpr_(call); + } + + void VisitStmt_(const ForNode *op) final { + PrimExpr old_loop_evtents = loop_extents; + loop_extents *= op->extent; + StmtExprVisitor::VisitStmt_(op); + loop_extents = old_loop_evtents; + } + + PrimExpr bulk_copy_bytes = 0; + PrimExpr loop_extents = 1; +}; + +class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { +public: + static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) { + TmaExpectTxRewriter rewriter(analyzer); + f.CopyOnWrite()->body = rewriter(f->body); + return f; + } + +private: + bool inside_tma_block_{false}; + bool visited_tma_load_{false}; + IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), + IterVarType::kDataPar); + + PrimExpr makeGetBarrier(PrimExpr barrier_id) { + return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)}); + } + + Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { + auto call = Call(DataType::Handle(), mbarrier_expect_tx(), + {makeGetBarrier(std::move(barrier_id)), std::move(bytes)}); + return Evaluate(call); + } + + TmaExpectTxRewriter(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {} + + Stmt VisitStmt_(const AttrStmtNode *op) final { + + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) { + // Check if this is the TMA block + bool flag = false; + if (op->condition.as()) { + flag = op->condition.as()->op.same_as(tl_shuffle_elect()); + } + if (op->condition.as() || flag) { + Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op); + + if (visited_tma_load_) { + auto then_case = op->then_case; + TmaTraitsCollector collector; + collector.Collect(then_case); + + Array stmts; + if (!is_zero(collector.BulkCopyBytes())) { + auto expect_tx = makeExpectTX(0, collector.BulkCopyBytes()); + stmts.push_back(expect_tx); + } + stmts.push_back(then_case); + if (stmts.size() == 1) { + 
return IfThenElse(op->condition, stmts[0], op->else_case); + } else { + auto seq_stmt = SeqStmt(stmts); + return IfThenElse(op->condition, seq_stmt, op->else_case); + } + } + visited_tma_load_ = false; + return ret; + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + auto arg0 = op->args[0].as(); + bool is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + op->op.same_as(tma_load()); + visited_tma_load_ = true; + Array new_args = op->args; + new_args.Set(is_1d_tma_load ? 2 : 1, + Call(DataType::Handle(), get_mbarrier(), + {IntImm(DataType::Int(32), 0)})); + return Call(op->dtype, op->op, new_args); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } +}; + +class TmaBarrierCollector : public IRVisitorWithAnalyzer { +public: + TmaBarrierCollector(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + Map tma_op_to_barrier_id() { + return tma_op_to_barrier_id_; + } + Map barrier_id_to_range() { return barrier_id_to_range_; } + +private: + void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) { + if (barrier_id_to_range_.count(barrier_id)) { + auto old_extent = barrier_id_to_range_[barrier_id]; + ICHECK_EQ(old_extent->value, extent->value) + << "barrier_id: " << barrier_id << " has different extent"; + barrier_id_to_range_.Set(barrier_id, extent); + } else { + barrier_id_to_range_.Set(barrier_id, extent); + } + } + + void VisitStmt_(const EvaluateNode *op) final { + if (const auto *call = op->value.as()) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); + } else if (call->op.same_as(mbarrier_expect_tx())) { + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); + } else if (call->op.same_as(builtin::ptx_arrive_barrier())) { + PrimExpr barrier_id = call->args[0]; + for (const auto &tma_call : pending_tma_ops_) { + tma_op_to_barrier_id_.Set(tma_call, barrier_id); + } + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); + pending_tma_ops_.clear(); + } else if (call->op.same_as(builtin::ptx_wait_barrier())) { + PrimExpr barrier_id = call->args[0]; + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + thread_var_ = iv; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + IterVar thread_var_; + std::vector pending_tma_ops_; + Map tma_op_to_barrier_id_; + Map barrier_id_to_range_; + Map buffer_data_to_buffer_; +}; + +class TmaSequenceCollector : public IRVisitorWithAnalyzer { +public: + TmaSequenceCollector(Map tma_op_to_barrier_id) + : tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)) {} + + std::vector GetSequence() { + std::vector clear_zero_list(expect_tx_count_, false); + int zero_idx = -1; + int zero_count = 0; + + for (auto v : sequence) { + if (v == 0) { + zero_count += 1; + zero_idx += 1; + } else { + if (zero_count == 1) { + clear_zero_list[zero_idx] = 
expect_[zero_idx] && !has_simt_copy_; + if (clear_zero_list[zero_idx] == false) { + int begin = int_sets_[zero_idx].min().as()->value; + int end = int_sets_[zero_idx].max().as()->value; + for (int i = begin; i <= end; ++i) { + restore_barrier_ids_.push_back(i); + } + } + } else { + for (int i{zero_idx}; i > zero_idx - zero_count; --i) { + int begin = int_sets_[i].min().as()->value; + int end = int_sets_[i].max().as()->value; + for (int i = begin; i <= end; ++i) { + restore_barrier_ids_.push_back(i); + } + } + } + zero_count = 0; + } + } + + return clear_zero_list; + } + + std::vector GetRestoreBarrierIds() { return restore_barrier_ids_; } + + void VisitStmt_(const ForNode *op) final { + var_int_set_.Set(op->loop_var, + arith::IntSet::FromMinExtent(op->min, op->extent)); + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(mbarrier_expect_tx())) { + auto call_ref = tvm::ffi::GetRef(op); + if (tma_op_to_barrier_id_.count(call_ref)) { + PrimExpr e = tma_op_to_barrier_id_[call_ref].as()->args[0]; + auto int_set = arith::EvalSet(e, var_int_set_); + expect_.push_back(if_depth_ == 1); + sequence.push_back(0); + int_sets_.push_back(int_set); + expect_tx_count_ += 1; + } + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + sequence.push_back(1); + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + has_simt_copy_ = true; + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode *op) final { + if_depth_ += 1; + + IRVisitorWithAnalyzer::VisitStmt(op->then_case); + + if (op->else_case) { + IRVisitorWithAnalyzer::VisitStmt(op->else_case.value()); + } + if_depth_ -= 1; + } + + std::vector sequence; + int expect_tx_count_{0}; + std::vector expect_; + bool has_simt_copy_{false}; + std::vector restore_barrier_ids_; + int if_depth_{0}; + Map tma_op_to_barrier_id_; + arith::Analyzer *analyzer_{}; + Map var_int_set_; + std::vector int_sets_; +}; + +class BarrierCreationRewriter : public StmtExprMutator { +public: + BarrierCreationRewriter(std::vector restore_barrier_ids, + PrimExpr producer_thread_extent, + int ensure_min_count = 0, + PrimExpr default_barrier_thread_count = 1) + : restore_barrier_ids_(std::move(restore_barrier_ids)), + producer_thread_extent_(std::move(producer_thread_extent)), + ensure_min_count_(ensure_min_count), + default_barrier_thread_count_(std::move(default_barrier_thread_count)) { + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(create_list_of_mbarrier())) { + size_t cur_n = op->args.size(); + size_t need_n = + std::max(cur_n, static_cast(ensure_min_count_)); + + // Mark barriers to restore across the full needed length, not just the + // original length, so newly appended entries can be restored as well. 
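+    // Illustrative example (all values are placeholders): with an original
+    //   create_list_of_mbarrier(128, 128),
+    // restore_barrier_ids_ = {1, 2}, ensure_min_count_ = 3 and
+    // producer_thread_extent_ = 32, the call is rewritten to
+    //   create_list_of_mbarrier(128, 32, 32)
+    // i.e. restored slots take the producer thread extent, and the appended
+    // slot 2 is restored as well instead of using the default arrive count.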
+ std::vector replace(need_n, false); + for (auto &id : restore_barrier_ids_) { + if (id >= 0 && static_cast(id) < replace.size()) { + replace[id] = true; + } + } + + Array new_args; + new_args.reserve(need_n); + + // Preserve/override existing entries + for (size_t i{0}; i < cur_n; ++i) { + if (replace[i]) { + new_args.push_back(producer_thread_extent_); + } else { + new_args.push_back(op->args[i]); + } + } + // Append additional barriers if required + for (size_t i = cur_n; i < need_n; ++i) { + if (replace[i]) { + new_args.push_back(producer_thread_extent_); + } else { + new_args.push_back(default_barrier_thread_count_); + } + } + + return Call(op->dtype, op->op, new_args); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + +private: + std::vector restore_barrier_ids_; + PrimExpr producer_thread_extent_; + int ensure_min_count_{0}; + PrimExpr default_barrier_thread_count_{1}; +}; + +// we trust mbarrier_wait_parity to be correct +class TmaBarrierRewriter : public IRMutatorWithAnalyzer { +public: + TmaBarrierRewriter(arith::Analyzer *analyzer, + Map tma_op_to_barrier_id, + Map barrier_id_to_range, + bool has_create_list_of_mbarrier) + : IRMutatorWithAnalyzer(analyzer), + tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)), + barrier_id_to_range_(std::move(barrier_id_to_range)), + has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {} + + static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) { + auto buffer_lca = DetectBufferAccessLCA(f); + Map buffer_data_to_buffer_; + for (auto [buffer, _] : buffer_lca) + buffer_data_to_buffer_.Set(buffer->data, buffer); + f = TmaExpectTxRewriter::Rewrite(f, analyzer); + TmaBarrierCollector collector(buffer_data_to_buffer_); + collector(f->body); + bool has_create_list_of_mbarrier = false; + PostOrderVisit(f->body, [&](const ObjectRef &node) { + if (const auto *call = node.as()) { + if (call->op.same_as(create_list_of_mbarrier())) { + has_create_list_of_mbarrier = true; + } else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) { + has_create_list_of_mbarrier = true; + } + } + }); + TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(), + collector.barrier_id_to_range(), + has_create_list_of_mbarrier); + f.CopyOnWrite()->body = rewriter(f->body); + // Compute the minimum number of barriers actually referenced in the body + // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA). + struct GetMbarrierMaxIdxCollector : public StmtExprVisitor { + int max_idx{-1}; + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(get_mbarrier())) { + if (op->args.size() == 1) { + if (const auto *imm = op->args[0].as()) { + max_idx = std::max(max_idx, static_cast(imm->value)); + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + }; + + GetMbarrierMaxIdxCollector max_idx_collector; + max_idx_collector(f->body); + int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count + + // For simple TMA-only producers, default barrier arrive count should be 1 + // (only the elected leader performs the TMA arrive/expect). 
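+    // Example (hedged): if the rewritten body only references get_mbarrier(0),
+    // ensure_min_count is 1, so an existing create_list_of_mbarrier(...) call
+    // is padded to at least one entry; entries appended this way keep the
+    // default arrive count of 1 unless their id is in restore_barrier_ids_.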
+ auto barrier_creation_rewriter = BarrierCreationRewriter( + rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_, + ensure_min_count, Integer(1)); + f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); + return f; + } + +private: + Stmt VisitStmt_(const BlockNode *op) { + auto block = tvm::ffi::GetRef(op); + if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() && + op->name_hint == MainBlockName) { + ICHECK(false) << "Please declare create_list_of_mbarrier."; + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) { + if (first_if) { + if (op->condition.as()) { + producer_thread_extent_ = + thread_var_->dom->extent - op->condition.as()->b; + } + TmaSequenceCollector collector(tma_op_to_barrier_id_); + collector(op->then_case); + clear_expect_list_ = collector.GetSequence(); + restore_barrier_ids_ = collector.GetRestoreBarrierIds(); + first_if = false; + + is_producer_ = true; + + auto then_case = StmtExprMutator::VisitStmt(op->then_case); + + is_producer_ = false; + Stmt else_case; + if (op->else_case.defined()) + else_case = StmtExprMutator::VisitStmt(op->else_case.value()); + return IfThenElse(op->condition, then_case, else_case); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "kWarpSpecializationScope") { + has_warp_specialization_ = true; + first_if = true; + } else if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_var_ = Downcast(op->node); + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id) + // so codegen can emit mbarrier[index]. This handles degenerate + // producer-only kernels where no arrive() is seen and mapping is empty. 
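+      // Illustrative rewrite (the barrier id 3 is a placeholder):
+      //   tma_load(shared_buf_ptr, ..., /*args[2]=*/3, ...)
+      //     -> tma_load(shared_buf_ptr, ..., get_mbarrier(3), ...)
+      // Multi-dimensional loads with a known mapping instead get their mapped
+      // barrier written into args[1], as handled further below.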
+ auto arg0 = op->args[0].as(); + bool is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + if (is_1d_tma_load && op->args.size() >= 3) { + if (const auto *imm = op->args[2].as()) { + Array new_args = op->args; + new_args.Set(2, Call(DataType::Handle(), get_mbarrier(), + {IntImm(DataType::Int(32), + static_cast(imm->value))})); + return Call(op->dtype, op->op, new_args); + } + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; + auto new_args = op->args; + auto arg0 = op->args[0].as(); + auto is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + if (is_1d_tma_load) { + new_args.Set(2, barrier_id); + } else { + new_args.Set(1, barrier_id); + } + return Call(op->dtype, op->op, new_args); + } else if (op->op.same_as(mbarrier_expect_tx())) { + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; + auto new_args = op->args; + new_args.Set(0, barrier_id); + if (!has_warp_specialization_) + clear_arrive_ = false; + else + clear_arrive_ = clear_expect_list_[cur_expect_idx_++]; + if (clear_arrive_) { + return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(), + new_args); + } + return Call(op->dtype, op->op, new_args); + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + if (clear_arrive_) { + clear_arrive_ = false; + return 0; + } + // by default, all threads must wait. + auto new_args = op->args; + return Call(op->dtype, op->op, new_args); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + Map tma_op_to_barrier_id_; + Map barrier_id_to_range_; + bool has_create_list_of_mbarrier_; + bool clear_arrive_{false}; + bool first_if{false}, has_warp_specialization_{false}, is_producer_{false}; + IterVar thread_var_; + int tma_expect_tx_{0}, cur_expect_idx_{0}; + std::vector clear_expect_list_; + std::vector restore_barrier_ids_; + PrimExpr producer_thread_extent_; +}; + +tvm::transform::Pass InjectTmaBarrier() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + // Check if function only uses threadIdx.x before proceeding + if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { + LOG(WARNING) << "InjectTmaBarrier will be disabled because the program " + "uses thread tags other than threadIdx.x\n" + << "If you want to use TMA barrier, please refactor " + "your program to use threadIdx.x only"; + // Return original function unchanged if other thread tags are found + return f; + } + arith::Analyzer analyzer; + return TmaBarrierRewriter::Rewrite(f, &analyzer); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/layout_inference.cc b/tilelang/original/src/transform/layout_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..1af8161474336c4df4e62d9916fb618b28b9c24c --- /dev/null +++ b/tilelang/original/src/transform/layout_inference.cc @@ -0,0 +1,1275 @@ +/*! 
+ * \file layout_inference.cc + * \brief infer the fragment/shared memory layout + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../layout/utils.h" +#include "../op/copy.h" +#include "../op/parallel.h" +#include "../op/region.h" +#include "../target/utils.h" + +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" +#include "common/loop_fusion_utils.h" +#include "common/loop_parallel_transform_utils.h" +#include "common/union_find.h" +#include "layout_reducer.h" +#include "loop_partition.h" +#include "loop_vectorize.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief collect the mapping from the buffer var to it allocated buffer + */ +class ThreadBindingCollector : public StmtExprVisitor { +public: + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + thread_binding_[iv->var.get()] = iv; + } + StmtExprVisitor::VisitStmt_(op); + } + + // The thread binding map + std::unordered_map thread_binding_; +}; + +using namespace tir; +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +struct LayoutInferenceResult { + Map layout_map; + Map for_map; + Map predicate_map; +}; + +class BufferUseDefCollector : public IRVisitorWithAnalyzer { +public: + BufferUseDefCollector(bool skip_thread_partition) + : skip_thread_partition_(skip_thread_partition) {} + + using arith::IRVisitorWithAnalyzer::IRVisitorWithAnalyzer; + + void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue, + LayoutMap &layout_map, const LayoutMap &strict_layout_map, + std::deque &q, std::vector &in_queue) { + auto num_infer = infer_list_.size(); + + // Range check for cur_infer_id + ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid."; + ICHECK_LT(cur_infer_id, num_infer) + << "cur_infer_id " << cur_infer_id << " is out of range, must be < " + << num_infer << "."; + + // Make sure we can safely access infer_list_[cur_infer_id] and + // thread_var_vec_[cur_infer_id] + auto &next = infer_list_[cur_infer_id]; + auto iter_var = thread_var_vec_[cur_infer_id]; + auto thread_bounds = thread_bounds_vec_[cur_infer_id]; + arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get(); + auto buffer_oob = buffer_oob_vec_[cur_infer_id]; + // Double-check that 'next' is valid + ICHECK(next.defined()) << "infer_list_[" << cur_infer_id + << "] is null inside run_infer_step."; + + // Check iter_var->dom and dom->extent + ICHECK(iter_var.defined()) + << "thread_var_vec_[" << cur_infer_id << "] is not defined."; + ICHECK(iter_var->dom.defined()) + << "iter_var->dom is not defined for infer_list_[" << cur_infer_id + << "]."; + ICHECK(iter_var->dom->extent.defined()) + << "iter_var->dom->extent is not defined for infer_list_[" + << cur_infer_id << "]."; + + const int64_t *extent_ptr = as_const_int(iter_var->dom->extent); + ICHECK(extent_ptr != nullptr) + << "iter_var->dom->extent is not a constant integer, which is " + "required for layout inference."; + + // Run InferLayout + auto updates = next->InferLayout(LayoutInferArgs{target_, + thread_bounds, + layout_map, + cur_analyzer, + buffer_oob, + {}, + let_var_to_expr_}, + level); + + // Process the returned updates + for (const auto &[buffer, layout] : updates) { + // Basic validity checks + ICHECK(buffer.defined()) << "InferLayout 
returned an undefined buffer."; + ICHECK(layout.defined()) << "InferLayout returned an undefined layout."; + + // Helper: propagate inferred layout to alias buffers (same data Var) + auto propagate_alias = [&](const Buffer &src_buffer, + const Layout &src_layout) { + if (!buffer_data_to_buffers_.count(src_buffer->data)) + return; + const auto &siblings = buffer_data_to_buffers_[src_buffer->data]; + for (const auto &sib : siblings) { + if (sib.same_as(src_buffer)) + continue; + bool shapes_equal = + src_layout->InputShape().size() == sib->shape.size(); + if (shapes_equal) { + for (size_t i = 0; i < src_layout->InputShape().size(); ++i) { + if (!analyzer_.CanProveEqual(src_layout->InputShape()[i], + sib->shape[i])) { + shapes_equal = false; + break; + } + } + } + Layout target_layout = + shapes_equal + ? src_layout + : src_layout->Reshape(sib->shape, &analyzer_, + Integer(src_buffer->dtype.bytes()), + Integer(sib->dtype.bytes())); + if (layout_map.count(sib)) { + ICHECK(target_layout->IsEqual(layout_map[sib].get())) + << "Get different layout for alias buffer " << sib + << " (data-shared with " << src_buffer + << ")\n current: " << target_layout->DebugOutput() + << "\n previous: " << layout_map[sib]->DebugOutput(); + } else { + layout_map.Set(sib, target_layout); + if (update_queue && use_list_.count(sib)) { + for (int idx : use_list_[sib]) { + EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map); + } + } + } + } + }; + + if (layout_map.count(buffer)) { + // If new layout contains the old one, update map + if (buffer.scope() == "local.fragment" && + level != InferLevel::kStrict && !strict_layout_map.count(buffer)) { + // Actually this test has been done in ParallelOp::InferLayout + // already. Just do it again to avoid missing implementations in other + // `TileOperator`s. 
+ + auto dst_layout_opt = layout.as(); + ICHECK(dst_layout_opt.has_value()) + << "Failed to cast layout to Fragment for buffer " << buffer + << ", layout type is " << layout->GetTypeKey(); + const auto &dst_layout = dst_layout_opt.value(); + auto src_layout_opt = layout_map[buffer].as(); + ICHECK(src_layout_opt.has_value()) + << "Failed to cast layout_map[buffer] to Fragment for buffer " + << buffer << ", layout type is " + << layout_map[buffer]->GetTypeKey(); + const auto &src_layout = src_layout_opt.value(); + ICHECK(dst_layout->InputDim() == src_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - src_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + if (ProveFragmentContains(src_layout, dst_layout, indices, indices, + inner_analyzer)) { + layout_map.Set(buffer, layout); + // Propagate to alias buffers as well + propagate_alias(buffer, layout); + continue; + } + } + // If already in map, ensure they are structurally equal + ICHECK(layout->IsEqual(layout_map[buffer].get())) + << "Get different layout for " << buffer + << "\n current layout: " << layout->DebugOutput() + << "\n previous layout: " << layout_map[buffer]->DebugOutput(); + // Ensure aliases are consistent too + propagate_alias(buffer, layout); + } else { + // Otherwise, update map + layout_map.Set(buffer, layout); + // Propagate to alias buffers (may enqueue their users) + propagate_alias(buffer, layout); + if (!update_queue) + continue; + + // Check if buffer exists in use_list_ + if (!use_list_.count(buffer)) { + LOG(WARNING) << "Layout inference failed for buffer " << buffer + << ". 
" + << "The buffer cannot be inferred with current layout " + "inference rules."; + continue; + } + + // Push back into BFS queue + for (int idx : use_list_[buffer]) { + ICHECK_GE(idx, 0) + << "Index in use_list_ for buffer " << buffer << " is negative."; + ICHECK_LT(idx, num_infer) + << "Index in use_list_ for buffer " << buffer + << " out of range: " << idx << " >= " << num_infer << "."; + + EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map); + } + } + } + }; + + void FinishInferQueue(InferLevel level, LayoutMap &layout_map, + const LayoutMap &strict_layout_map, std::deque &q, + std::vector &in_queue) { + auto num_infer = infer_list_.size(); + + while (!q.empty()) { + int cur_infer_id = q.front(); + q.pop_front(); + // Range check again, just to be safe + ICHECK_GE(cur_infer_id, 0); + ICHECK_LT(cur_infer_id, num_infer); + + in_queue[cur_infer_id] = false; + RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q, + in_queue); + } + }; + + LayoutInferenceResult Run() { + // Basic consistency check: infer_list_ and thread_var_vec_ should have the + // same size + ICHECK_EQ(infer_list_.size(), thread_var_vec_.size()) + << "Size mismatch: infer_list_ and thread_var_vec_ must match in " + "length."; + ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size()) + << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in " + "length."; + ICHECK_EQ(analyzer_vec_.size(), infer_list_.size()) + << "Size mismatch: analyzer_vec_ and infer_list_ must match in " + "length."; + ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size()) + << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " + "length."; + + DLOG(INFO) << "[InferLayout] all participating operators:" << '\n'; + for (int i = 0; i < infer_list_stmt_.size(); ++i) { + DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n'; + } + + // If needed, you can also check that annotated_layout_map_ is not empty, or + // anything else relevant to your setup. + + // Copy the annotated layout map to local variable + Map layout_map = annotated_layout_map_; + Map strict_layout_map; + int num_infer = infer_list_.size(); + + // Prepare BFS queue for iterative inference + std::deque q; + std::vector in_queue(num_infer, true); + for (int i = 0; i < num_infer; i++) { + // Check that each infer_list_ entry is valid + ICHECK(infer_list_[i].defined()) + << "infer_list_[" << i + << "] is null. The inference object is not allocated properly."; + + // Check that each thread_var_vec_ entry is defined + if (!thread_var_vec_[i].defined() && skip_thread_partition_) { + thread_var_vec_[i] = thread_var_; + } + q.push_back(i); + } + + // step 1: infer strict layout + for (int i = 0; i < num_infer; i++) { + RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map, + q, in_queue); + } + + for (const auto &[buffer, layout] : layout_map) { + strict_layout_map.Set(buffer, layout); + } + + // step 2: infer common layout with BFS + FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q, + in_queue); + + // step 3: relax constraints to free and re-run + InferInFreeMode(layout_map, strict_layout_map); + + // step 4: finalize alias layouts by Var + // For each storage var, if any buffer in the group has a layout, + // propagate (reshape if needed) to the rest to ensure completeness. 
+ for (const auto &[var, buffers] : buffer_data_to_buffers_) { + // Find a representative with existing layout + Optional rep; + Optional rep_layout; + for (const auto &buf : buffers) { + if (layout_map.count(buf)) { + rep = buf; + rep_layout = layout_map[buf]; + break; + } + } + if (!rep_layout.defined()) + continue; + for (const auto &buf : buffers) { + if (!layout_map.count(buf)) { + bool shapes_equal = + rep_layout.value()->InputShape().size() == buf->shape.size(); + if (shapes_equal) { + for (size_t i = 0; i < rep_layout.value()->InputShape().size(); + ++i) { + if (!analyzer_.CanProveEqual(rep_layout.value()->InputShape()[i], + buf->shape[i])) { + shapes_equal = false; + break; + } + } + } + + Layout reshaped = shapes_equal + ? rep_layout.value() + : rep_layout.value()->Reshape( + buf->shape, &analyzer_, + Integer(rep.value()->dtype.bytes()), + Integer(buf->dtype.bytes())); + layout_map.Set(buf, reshaped); + } + } + } + + // Check that all local.fragment buffers have inferred layouts + for (const auto &[buffer, _] : use_list_) { + if (buffer.scope() == "local.fragment") { + ICHECK_NE(layout_map.count(buffer), 0) + << "The layout for fragment " << buffer + << " can not be inferred correctly."; + } + } + + // Collect layout info for For nodes + Map for_map; + Map predicate_map; + ICHECK(infer_list_.size() == thread_var_vec_.size()) + << "infer_list_ and thread_var_vec_ size mismatch"; + for (int i = 0; i < infer_list_.size(); i++) { + TileOperator base_infer = std::move(infer_list_[i]); + auto thread_var = thread_var_vec_[i]; + + // Check if base_infer is valid + ICHECK(base_infer.defined()) << "Null pointer encountered in " + "infer_list_ while collecting for_map."; + if (auto for_infer = base_infer.as()) { + // Check that the loop layout is defined + ICHECK(for_infer->GetLoopLayout().defined()) + << "The Layout for Parallel for cannot be inferred correctly:\n" + << for_infer->GetRoot(); + for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout()); + // thread_var_ should be defined if we rely on it + ICHECK(thread_var.defined()) + << "thread_var is not defined. Cannot retrieve predicate."; + + if (auto predicate = for_infer->GetPredicate(thread_var->var)) { + predicate_map.Set(for_infer->GetRoot(), predicate.value()); + } + } + } + + return {layout_map, for_map, predicate_map}; + } + + void Collect(const PrimFunc &f) { + for (const auto &[_, buffer] : f->buffer_map) { + if (buffer_data_to_buffers_.count(buffer->data)) { + auto buffers = buffer_data_to_buffers_[buffer->data]; + buffers.push_back(buffer); + buffer_data_to_buffers_.Set(buffer->data, buffers); + } else { + buffer_data_to_buffers_.Set(buffer->data, {buffer}); + } + } + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) + << "Layout_Inference: Require the target attribute"; + target_ = target.value(); + this->operator()(f->body); + } + +private: + Map GetBufferMap() const { + Map buffer_map; + for (const auto &[var, buffers] : buffer_data_to_buffers_) { + // Use the first buffer for each var + // TODO(lei): phaseout buffer_map in future. + if (!buffers.empty()) { + buffer_map.Set(var, buffers[0]); + } + } + return buffer_map; + } + + // Return true if all buffers that this op (idx) touches already have + // inferred layouts in layout_map. Used to prioritize enqueue order. 
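+  // Example (hedged): if op 3 touches {A_frag, B_frag} and both already have
+  // entries in layout_map, EnqueueWithPriority (below) pushes 3 to the front
+  // of the BFS deque so fully-constrained ops are re-inferred before partially
+  // constrained ones; otherwise the op is appended to the back.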
+ bool ShouldPrioritize(int idx, const LayoutMap &layout_map) const { + auto it = op_touched_buffers_.find(idx); + if (it == op_touched_buffers_.end() || it->second.empty()) + return false; + for (const auto &buf : it->second) { + if (!layout_map.count(buf)) + return false; + } + return true; + } + + // Enqueue idx to q with priority if all its buffers already + // have layouts. Also guards against duplicates and self-enqueue. + void EnqueueWithPriority(int idx, std::deque &q, + std::vector &in_queue, int cur_infer_id, + const LayoutMap &layout_map) const { + if (idx == cur_infer_id) + return; + if (idx < 0 || idx >= static_cast(in_queue.size())) + return; + if (in_queue[idx]) + return; + in_queue[idx] = true; + if (ShouldPrioritize(idx, layout_map)) { + q.push_front(idx); + } else { + q.push_back(idx); + } + } + + void VisitExpr_(const CallNode *op) final { + IRVisitorWithAnalyzer::VisitExpr_(op); + // Do not analysis the call node to the global function. + if (op->op.as()) + return; + + auto p = ParseOperator(tvm::ffi::GetRef(op)); + if (p.defined()) { + for (const auto &arg : op->args) { + if (auto buffer = getBufferFromAccessPtr(arg)) { + addToUseList(buffer.value()); + } else if (auto buffer = getBufferFromRegion(arg)) { + addToUseList(buffer.value()); + } + // Check if the argument uses any LetStmt variables that reference + // fragment buffers. If so, add those buffers to the use list. + // This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0], ...) + CollectFragmentBuffersFromExpr(arg); + } + // Compute thread_var_ and thread_bounds_ + thread_var_vec_.push_back(thread_var_); + if (analyzer_.const_int_bound.IsBound(thread_var_->var)) { + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto min_value = const_int_bound->min_value; + auto max_value = const_int_bound->max_value; + auto extent = max_value - min_value + 1; + auto dtype = thread_var_->var.dtype(); + thread_bounds_vec_.push_back(Range::FromMinExtent( + IntImm(dtype, min_value), IntImm(dtype, extent))); + } else { + thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); + } + analyzer_vec_.push_back(analyzer_.Clone()); + + // Compute buffer oob for each buffer in the op + if (const auto *copy = p.as()) { + auto src_tensor = copy->src; + auto dst_tensor = copy->dst; + auto src_range = copy->src_range; + auto dst_range = copy->dst_range; + bool src_oob = false; + bool dst_oob = false; + for (size_t i = 0; i < src_range.size(); i++) { + if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <= + src_tensor->shape[i], + arith::ProofStrength::kSymbolicBound)) { + src_oob = true; + break; + } + } + for (size_t i = 0; i < dst_range.size(); i++) { + if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <= + dst_tensor->shape[i], + arith::ProofStrength::kSymbolicBound)) { + dst_oob = true; + break; + } + } + buffer_oob_vec_.push_back(src_oob || dst_oob); + } else { + buffer_oob_vec_.push_back(false); + } + + // Add the tile operator to infer_list_ + infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); + infer_list_.push_back(std::move(p)); + } + } + + Optional getBufferFromAccessPtr(const PrimExpr &expr) { + if (auto bl = expr.as()) { + return bl->buffer; + } + auto call = expr.as(); + if (!call) { + return std::nullopt; + } + if (call->op.same_as(builtin::tvm_access_ptr())) { + auto var_opt = call->args[1].as(); + if (!var_opt.has_value()) { + LOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: " + << call->args[1]->GetTypeKey(); + return std::nullopt; + } + const auto 
&var = var_opt.value(); + if (buffer_data_to_buffers_.count(var)) { + const auto &buffers = buffer_data_to_buffers_[var]; + if (!buffers.empty()) { + return buffers[0]; // Return the first buffer + } + } + return std::nullopt; + } + return std::nullopt; + } + + Optional getBufferFromRegion(const PrimExpr &expr) { + if (auto call = expr.as()) { + if (call->op.same_as(RegionOp::Get())) { + if (auto bl = call->args[0].as()) { + return bl->buffer; + } + return std::nullopt; + } + } + return std::nullopt; + } + + void addToUseList(const Buffer &buffer) { + // buffer scope must be local.fragment + if (buffer.scope() != "local.fragment") { + return; + } + int infer_idx = infer_list_.size(); + if (use_list_.find(buffer) == use_list_.end()) { + use_list_[buffer] = {}; + } + use_list_[buffer].push_back(infer_idx); + + // Track which buffers this op (infer_idx) touches for prioritization. + // Avoid duplicates. + auto &vec = op_touched_buffers_[infer_idx]; + bool exists = false; + for (const auto &b : vec) { + if (b.same_as(buffer)) { + exists = true; + break; + } + } + if (!exists) + vec.push_back(buffer); + } + + void VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kParallel) { + auto infer = ParallelOp(tvm::ffi::GetRef(op)); + for (const auto &[buffer, _] : infer->GetIndiceMap()) { + addToUseList(buffer); + } + + PostOrderVisit(op->body, [this](const ObjectRef &node) { + if (auto *buffer_load = node.as()) { + if (buffer_load->buffer.defined() && + buffer_load->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(buffer_load->buffer->data)) { + // Check if this buffer is already in the list + auto buffers = buffer_data_to_buffers_[buffer_load->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(buffer_load->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(buffer_load->buffer); + buffer_data_to_buffers_.Set(buffer_load->buffer->data, buffers); + DLOG(INFO) << "[LayoutInference] BufferStore: added buffer " + << buffer_load->buffer + << " buffer.get() = " << buffer_load->buffer.get() + << " data = " << buffer_load->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(buffer_load->buffer->data, + {buffer_load->buffer}); + DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " + << buffer_load->buffer + << " buffer.get() = " << buffer_load->buffer.get() + << " data = " << buffer_load->buffer->data.get(); + } + } + } else if (auto *buffer_store = node.as()) { + if (buffer_store->buffer.defined() && + buffer_store->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(buffer_store->buffer->data)) { + auto buffers = + buffer_data_to_buffers_[buffer_store->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(buffer_store->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(buffer_store->buffer); + buffer_data_to_buffers_.Set(buffer_store->buffer->data, + buffers); + DLOG(INFO) << "[LayoutInference] BufferStore: added buffer " + << buffer_store->buffer + << " buffer.get() = " << buffer_store->buffer.get() + << " data = " << buffer_store->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(buffer_store->buffer->data, + {buffer_store->buffer}); + DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " + << buffer_store->buffer + << " buffer.get() = " << buffer_store->buffer.get() + << " data = " << buffer_store->buffer->data.get(); + } + } + } + }); + infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); + 
infer_list_.push_back(std::move(infer)); + thread_var_vec_.push_back(thread_var_); + if (thread_var_.defined() && + analyzer_.const_int_bound.IsBound(thread_var_->var)) { + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto dtype = thread_var_->var.dtype(); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + thread_bounds_vec_.push_back(Range::FromMinExtent( + IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent))); + } else { + thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); + } + analyzer_vec_.push_back(analyzer_.Clone()); + buffer_oob_vec_.push_back(false); + } else { + IRVisitorWithAnalyzer::VisitStmt(op->body); + } + } + + void VisitStmt_(const BlockNode *op) final { + for (auto buffer : op->alloc_buffers) { + if (buffer_data_to_buffers_.count(buffer->data)) { + auto buffers = buffer_data_to_buffers_[buffer->data]; + buffers.push_back(buffer); + buffer_data_to_buffers_.Set(buffer->data, buffers); + } else { + buffer_data_to_buffers_.Set(buffer->data, {buffer}); + } + } + + // First, visit the block body to collect all buffers from + // BufferLoad/BufferStore + IRVisitorWithAnalyzer::VisitStmt_(op); + + // After visiting, apply layouts to all collected buffers + if (op->annotations.count(attr::kLayoutMap)) { + // Check if the layout map is Map + auto map = + op->annotations.Get(attr::kLayoutMap)->as>().value(); + for (const auto &[var, layout] : map) { + ICHECK(buffer_data_to_buffers_.count(var)) + << "buffer " << var << " is not found in the block"; + const auto &buffers = buffer_data_to_buffers_[var]; + ICHECK(!buffers.empty()) << "buffer list for " << var << " is empty"; + // Apply layout to all buffers associated with this var + for (const auto &buffer : buffers) { + + // Reshape the layout to match the buffer's shape + // Check if shapes are structurally equal + bool shapes_equal = + layout->InputShape().size() == buffer->shape.size(); + if (shapes_equal) { + for (size_t i = 0; i < layout->InputShape().size(); ++i) { + if (!analyzer_.CanProveEqual(layout->InputShape()[i], + buffer->shape[i])) { + shapes_equal = false; + break; + } + } + } + + if (shapes_equal) { + annotated_layout_map_.Set(buffer, layout); + } else { + // Use the first buffer sharing this var as the base for dtype ratio + int base_bytes = buffers[0]->dtype.bytes(); + auto reshaped_layout = + layout->Reshape(buffer->shape, &analyzer_, Integer(base_bytes), + Integer(buffer->dtype.bytes())); + annotated_layout_map_.Set(buffer, reshaped_layout); + } + } + } + } + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + void VisitStmt_(const LetStmtNode *op) final { + // Record Let variable to its bound expression. + // This enables tracking fragment buffer accesses through let bindings. + let_var_to_expr_.Set(op->var, op->value); + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + // Helper: recursively collect fragment buffers from an expression, + // following let bindings chain. 
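+  // Example (from the use case noted above): given
+  //   a = block_mask_f[i]        // LetStmt binding recorded in let_var_to_expr_
+  //   T.copy(A[a, 0], B_frag)
+  // visiting the copy argument reaches Var `a`, follows let_var_to_expr_ back
+  // to block_mask_f[i], and adds block_mask_f to use_list_ when it is a
+  // local.fragment buffer.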
+ void CollectFragmentBuffersFromExpr(const PrimExpr &expr) { + PostOrderVisit(expr, [this](const ObjectRef &node) { + if (auto bl = node.as()) { + if (bl->buffer.defined() && bl->buffer.scope() == "local.fragment") { + addToUseList(bl->buffer); + } + } else if (auto var_node = node.as()) { + auto var = tvm::ffi::GetRef(var_node); + if (let_var_to_expr_.count(var)) { + CollectFragmentBuffersFromExpr(let_var_to_expr_[var]); + } + } + }); + } + + void VisitExpr_(const BufferLoadNode *op) final { + // Collect buffer from BufferLoad + if (op->buffer.defined() && op->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(op->buffer->data)) { + // Check if this buffer is already in the list + auto buffers = buffer_data_to_buffers_[op->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(op->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(op->buffer); + buffer_data_to_buffers_.Set(op->buffer->data, buffers); + DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer " + << op->buffer << " buffer.get() = " << op->buffer.get() + << " data = " << op->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer}); + DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer " << op->buffer + << " buffer.get() = " << op->buffer.get() + << " data = " << op->buffer->data.get(); + } + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + // Collect buffer from BufferStore + if (op->buffer.defined() && op->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(op->buffer->data)) { + // Check if this buffer is already in the list + auto buffers = buffer_data_to_buffers_[op->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(op->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(op->buffer); + buffer_data_to_buffers_.Set(op->buffer->data, buffers); + DLOG(INFO) << "[LayoutInference] BufferStore: added buffer " + << op->buffer << " buffer.get() = " << op->buffer.get() + << " data = " << op->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer}); + DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " << op->buffer + << " buffer.get() = " << op->buffer.get() + << " data = " << op->buffer->data.get(); + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + Map> buffer_data_to_buffers_; + // Map from LetStmt variable to its bound expression + Map let_var_to_expr_; + std::vector infer_list_stmt_; + std::vector infer_list_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + use_list_; + // Per-op list of buffers it touches (fragment scope), used for prioritization + std::unordered_map> op_touched_buffers_; + // This is a workaround for cpu backend, + // we need to define a thread_var for the serial loop. 
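+  // By default thread_var_ below spans the trivial single-iteration range
+  // [0, 1), i.e. one "virtual thread", so downstream thread-bound reasoning
+  // still sees a well-defined (if degenerate) range on such backends.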
+ IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), + IterVarType::kDataPar); + std::vector thread_var_vec_; + std::vector thread_bounds_vec_; + std::vector> analyzer_vec_; + std::vector buffer_oob_vec_; + Target target_; + LayoutMap annotated_layout_map_; + bool skip_thread_partition_{false}; + + std::vector BackupInferList() { + std::vector back_infer_list; + back_infer_list.reserve(infer_list_.size()); + for (auto &&p : infer_list_) { + back_infer_list.push_back(p->Clone()); + } + return back_infer_list; + } + + void InferInFreeMode(LayoutMap &layout_map, + const LayoutMap &strict_layout_map) { + + DLOG(INFO) << "Enforced layout maps:" << '\n'; + for (auto &&[k, v] : layout_map) { + DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n'; + } + DLOG(INFO) << '\n'; + + // Group operators into connected components + UnionFind uf; + for (int i = 0; i < infer_list_.size(); i++) { + uf.MakeSet(i); + } + for (const auto &[buffer, infer_indices] : use_list_) { + if (infer_indices.empty()) + continue; + + // Union all infer_list_ indices that share the same Buffer object + int first_idx = infer_indices[0]; + for (size_t i = 1; i < infer_indices.size(); i++) { + uf.Union(first_idx, infer_indices[i]); + } + } + // Additionally, union across buffers that share the same underlying + // buffer->data (Var). This handles cases like reshape where multiple + // Buffer objects alias the same storage. + for (const auto &[var, buffers] : buffer_data_to_buffers_) { + std::vector merged; + for (const auto &buf : buffers) { + auto it = use_list_.find(buf); + if (it != use_list_.end()) { + const auto &vec = it->second; + merged.insert(merged.end(), vec.begin(), vec.end()); + } + } + if (merged.size() > 1) { + std::sort(merged.begin(), merged.end()); + merged.erase(std::unique(merged.begin(), merged.end()), merged.end()); + int first = merged[0]; + for (size_t i = 1; i < merged.size(); ++i) { + uf.Union(first, merged[i]); + } + } + } + + std::unordered_map> components; + for (int i = 0; i < infer_list_.size(); i++) { + int root = uf.Find(i); + components[root].push_back(i); + } + // Create a map from root to buffers + std::unordered_map> components_buffers; + for (const auto &[buffer, infer_indices] : use_list_) { + int root = uf.Find(infer_indices[0]); + components_buffers[root].push_back(buffer); + } + // Keep components_buffers for debug purpose + (void)components_buffers; + + // For each component, try each op as root, and determine the least + // replicated one + std::deque q; + std::vector in_queue(infer_list_.size(), false); + + for (auto &&[root, members] : components) { + DLOG(INFO) << "======================= processing component " << root + << '\n'; + decltype(infer_list_) best_infer_list; + LayoutMap best_layout_map; + int64_t min_reg_num = INT64_MAX; + int min_reg_num_infer_root = -1; + + // Try each member as the root of inference for this component + for (int attempt_infer_root : members) { + DLOG(INFO) << "----------------------- try root " << attempt_infer_root + << " members " << members.size() << '\n'; + // Backup the current infer_list_ state + auto back_infer_list = BackupInferList(); + // Copy the current layout_map for temporary use + LayoutMap tmp_layout_map = layout_map; + bool do_update = true; + try { + // Run inference starting from attempt_infer_root + RunInferStep(attempt_infer_root, InferLevel::kFree, true, + tmp_layout_map, strict_layout_map, q, in_queue); + FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map, + q, in_queue); + + // 
After the first search, run inference for all other members in + // order + for (int other_infer_root : members) { + if (other_infer_root != attempt_infer_root) { + RunInferStep(other_infer_root, InferLevel::kFree, true, + tmp_layout_map, strict_layout_map, q, in_queue); + FinishInferQueue(InferLevel::kFree, tmp_layout_map, + strict_layout_map, q, in_queue); + } + } + } catch (const LayoutConflictException &e) { + do_update = false; + DLOG(INFO) << "attempt failed due to LayoutConflictException " + << e.what() << '\n'; + } catch (const NormalizeIterException &e) { + do_update = false; + DLOG(INFO) << "attempt failed due to NormalizeIterException " + << e.what() << '\n'; + } catch (const LoopLayoutInjectiveException &e) { + do_update = false; + DLOG(INFO) << "attempt failed due to LoopLayoutInjectiveException " + << e.what() << '\n'; + } + + if (do_update) { + // Compute the total register number for this layout + int64_t reg_num = 0; + for (const auto &[buffer, layout] : tmp_layout_map) { + if (auto frag = layout.as()) { + int64_t frag_reg_num = 1; + for (auto i : frag.value()->OutputShape()) { + auto pci = as_const_int(i); + ICHECK(pci != nullptr) + << "Can not use non-constant range to " + "iterate over a fragment/local " + "buffer. Non-constant shape expr is: " + << i + << ". This is possibly because you use symbolic shape when " + "accessing a fragment/local buffer."; + frag_reg_num *= *pci; + } + reg_num += frag_reg_num; + } + } + // Update the best plan if this one uses fewer registers + if (reg_num < min_reg_num || + (reg_num == min_reg_num && + attempt_infer_root < min_reg_num_infer_root)) { + best_infer_list = + BackupInferList(); // Use backup to avoid moving out infer_list_ + best_layout_map = tmp_layout_map; + min_reg_num = reg_num; + min_reg_num_infer_root = attempt_infer_root; + } + } + // Restore infer_list_ state for the next attempt + infer_list_ = std::move(back_infer_list); + } + ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n'; + // Apply the best plan for this component + infer_list_ = std::move(best_infer_list); + layout_map = best_layout_map; + DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = " + << min_reg_num_infer_root << '\n'; + } + } +}; + +class LayoutInferencer : public IRMutatorWithAnalyzer { +public: + static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) { + arith::Analyzer analyzer; + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = ParallelLoopFuser::Fuse(f->body); + BufferUseDefCollector collector(skip_thread_partition); + collector.Collect(f); + auto result = collector.Run(); + LayoutInferencer substituter(result, skip_thread_partition, &analyzer); + fptr->body = substituter.VisitStmt(f->body); + return f; + } + +private: + LayoutInferencer(const LayoutInferenceResult &result, + bool skip_thread_partition, arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer), result_(result), + skip_thread_partition_(skip_thread_partition) {}; + + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + /** + * @brief Visit and mutate a Block node to attach inferred layout information. + * + * Converts the visited Block via the base visitor, asserts that every buffer + * allocated with scope "local.framgent" has an inferred layout in + * result_.layout_map, and attaches result_.layout_map to the Block's + * annotations under attr::kLayoutMap. + * + * If any "local.framgent" buffer lacks an entry in result_.layout_map an + * ICHECK will fail with the offending buffer printed. 
+ * + * @return Stmt The (possibly modified) Block statement with the layout-map + * annotation set. + */ + Stmt VisitStmt_(const BlockNode *op) final { + Block block = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + + for (auto buffer : block->alloc_buffers) { + if (buffer.scope() == "local.framgent") { + ICHECK(result_.layout_map.count(buffer)) + << "Cannot inference fragment layout for " << buffer; + } + } + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map); + return block; + } + + /** + * @brief Visit and transform For nodes according to inferred layout + * information. + * + * If the For node is present in result_.for_map, this method applies + * loop-level layout-driven transformations: it optionally partitions the loop + * across the thread index, vectorizes the loop body, and wraps the loop with + * a predicate if one was inferred for the loop root. + * + * Detailed behavior: + * - Reads reducer information from the For node's attr::kReducerInfo + * annotation (if present) to detect reduction targets. + * - Detects register-local buffer stores (buffers with scope "local") in the + * original loop body; if only register-local stores are present the loop is + * treated as a register-local scenario and is not partitioned across + * threads. + * - Obtains the loop layout from result_.for_map[root] and, unless the loop + * is register-local or skip_thread_partition_ is set, partitions the loop via + * PartitionLoop using thread_var_ and analyzer_. + * - Scans the transformed loop body to determine whether it accesses any + * non-local buffers (scopes other than "local" or "local.fragment"). + * - Scans the transformed loop body to detect reducers (based on + * reducer_info). If a reducer is present the loop is NOT vectorized + * (reduction axes are excluded from vectorization as a conservative + * workaround). + * - If the loop has non-local accesses and no reducer, the loop is vectorized + * via VectorizeLoop. + * - If a predicate exists in result_.predicate_map for the loop root and the + * loop was partitioned, the method returns an IfThenElse surrounding the + * (possibly partitioned/vectorized) loop with that predicate; otherwise it + * returns the transformed For. + * + * @return The possibly transformed For statement (or an IfThenElse wrapping + * it) + */ + Stmt VisitStmt_(const ForNode *op) final { + Map reducer_info; + if (op->annotations.count(attr::kReducerInfo)) + reducer_info = op->annotations.Get(attr::kReducerInfo) + ->as>() + .value(); + if (!result_.for_map.count(tvm::ffi::GetRef(op))) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + // the analyzer will be modified in PartitionLoop and VectorizeLoop + // we need to save its state to prevent conflicted bindings + auto saved_analyzer = analyzer_->Clone(); + For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + auto root = tvm::ffi::GetRef(op); + // This check is a workaround to support T.Parallel for local buffers. + // For example: + // for i in T.Parallel(1024): + // A_local[i] = A_global[i] + // Here, A_local is a register-local buffer held independently by each + // thread, so explicit thread binding is not required. 
+ bool store_into_local = false; + PostOrderVisit(root, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + if (store->buffer.scope() == "local") { + store_into_local = true; + } + // if the case is like: + // for i in T.Parallel(1024): + // A_local[i] = B_global[i] + // A_frag[i] = A_global[i] + // exception will be raise in Parallel::LayoutInference + } + }); + // This check if for the loop that only manuplates "local" buffers, + // for i in T.Parallel(1024): + // A_local[i] = B_local[i] + // Though this might be illegal + // We use PostOrderVisit to detect whether the loop only manuplates + // "local" buffers, which indicates register usage and justifies skipping + // thread binding. + bool local_register_only = true; + PostOrderVisit(root, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + if (store->buffer.scope() != "local") { + local_register_only = false; + } + } else if (const auto *load = obj.as()) { + if (load->buffer.scope() != "local") { + local_register_only = false; + } + } + }); + + auto loop_layout = result_.for_map[root]; + // FIXME: tell in-Parallel and out-of-Parallel `local`s apart + // NOTE(lei): a bit ugly, we should rethink about this part in future. + bool parallel_loop = + !skip_thread_partition_ && !local_register_only && !store_into_local; + + if (parallel_loop) { + for_node = + PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout); + } + // If none thread bindings are provided, partition the loop + bool has_non_local = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (const auto *load = obj.as()) { + String scope = load->buffer.scope(); + if (scope != "local" && scope != "local.fragment") { + has_non_local = true; + } + } else if (const auto *store = obj.as()) { + String scope = store->buffer.scope(); + if (scope != "local" && scope != "local.fragment") { + has_non_local = true; + } + } + }); + // Workaround: if reducer is presented, don't vectorize loop + // Best solution should be isolate reduction axis out of vectorization + bool has_reducer = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (!has_reducer) + if (const auto *store = obj.as()) { + has_reducer = reducer_info.count(store->buffer->data) != 0; + } + }); + + // If a cast operation exists, vectorization may still be required + bool has_cast_operations = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (const auto *cast = obj.as()) { + // Check if this is a non-reducer store with Cast operation + DataType src_type = cast->value.dtype(); + DataType dst_type = cast->dtype; + bool src_ok = + src_type.is_float() || src_type.is_bfloat() || src_type.is_float8(); + bool dst_ok = + dst_type.is_float() || dst_type.is_bfloat() || dst_type.is_float8(); + if (src_ok && dst_ok && TargetIsCuda(Target::Current())) { + has_cast_operations = true; + } + } + }); + + if ((has_non_local || has_cast_operations) && !has_reducer) { + DLOG(INFO) << "Try to vectorize loop"; + for_node = VectorizeLoop(for_node, saved_analyzer.get()); + } + + if (result_.predicate_map.count(root) && parallel_loop) { + return IfThenElse(result_.predicate_map[root], for_node); + } else { + return for_node; + } + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + if (iv->thread_tag == "threadIdx.x") { + thread_var_ = iv; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + +private: + const 
LayoutInferenceResult result_; + IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), + IterVarType::kDataPar); + bool skip_thread_partition_{false}; +}; + +tvm::transform::Pass LayoutInference() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body); + ThreadBindingCollector collector; + collector(f->body); + bool has_thread_binding = !collector.thread_binding_.empty(); + bool skip_thread_partition = !has_thread_binding; + return LayoutInferencer::Substitute(std::move(f), skip_thread_partition); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/layout_reducer.cc b/tilelang/original/src/transform/layout_reducer.cc new file mode 100644 index 0000000000000000000000000000000000000000..957918c971b9fb30791065d01666c8cd6953303f --- /dev/null +++ b/tilelang/original/src/transform/layout_reducer.cc @@ -0,0 +1,410 @@ +/*! + * \file layout_reducer.cc + * + * Compute layout for local.reducer buffers and lower them to local.fragment. + */ + +#include +#include +#include +#include +#include +#include + +#include "../layout/layout.h" +#include "../op/fill.h" +#include "../op/finalize_reducer.h" +#include "../op/region.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "layout_reducer.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace tir::transform; +using arith::IRMutatorWithAnalyzer; + +/** + * @brief Construct a ReducerInfoNode from textual op and replication + * descriptors. + * + * Maps op_str to a ReducerOpType ("sum" → SUM, "max" → MAX, "min" → MIN) and + * rep_str to a ReducerRepType ("all" → ALL, "none" → NONE). + * + * @param op_str String identifying the reducer operation. + * @param rep_str String identifying the replication behavior. + * @throws RuntimeError if op_str or rep_str is not one of the supported values + * (triggers ICHECK). + */ +ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) { + if (op_str == "sum") + op = ReducerOpType::SUM; + else if (op_str == "max") + op = ReducerOpType::MAX; + else if (op_str == "min") + op = ReducerOpType::MIN; + else + ICHECK(false) << "Unrecognized reducer_info op: " << op_str; + + if (rep_str == "all") + rep = ReducerRepType::ALL; + else if (rep_str == "none") + rep = ReducerRepType::NONE; + else + ICHECK(false) << "Unrecognized reducer_info rep: " << rep_str; +} + +class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer { +public: +private: + /** + * @brief Visit an attribute statement and capture the IterVar for + * threadIdx.x. + * + * If the attribute key is `tir::attr::thread_extent` and the node is an + * `IterVar` whose `thread_tag` equals `"threadIdx.x"`, this sets the + * mutator's `thread_var_` to that IterVar (after asserting the iterator's + * extent is an `IntImm`). The previous `thread_var_` is preserved and + * restored after delegating to the base visitor. Delegates all traversal work + * to `IRMutatorWithAnalyzer::VisitStmt_`. + * + * Side effects: + * - Temporarily updates the member `thread_var_` during traversal of the + * child statement so subsequent visitors can read the thread index IterVar. 
+ * + * @return The possibly mutated statement returned by the base visitor. + */ + Stmt VisitStmt_(const AttrStmtNode *op) final { + auto prev_thread_var = thread_var_; + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + auto result = IRMutatorWithAnalyzer::VisitStmt_(op); + thread_var_ = prev_thread_var; + return result; + } + + /** + * @brief Visits a TIR Block node to collect reducer metadata and apply + * discovered buffer layouts. + * + * This method: + * - Extracts reducer information from the block's `attr::kReducerInfo` + * annotation and populates the internal reducer_info_map_. + * - Registers allocated buffers by mapping each buffer's data Var to its + * Buffer in var_to_buffer_. + * - Recursively visits and rewrites the block body via the base mutator. + * - Merges any layouts accumulated in new_layout_map_ into the block's + * `attr::kLayoutMap` annotation (creating or extending the annotation), then + * clears new_layout_map_ for subsequent blocks. + * + * Side effects: + * - Updates reducer_info_map_, var_to_buffer_, and may set the block-level + * `kLayoutMap` annotation. + * - Clears new_layout_map_ after merging. + * + * @param op The Block node being visited. + * @return Stmt The potentially modified Block statement (as a Stmt). + */ + Stmt VisitStmt_(const BlockNode *op) final { + // Record annotations + if (op->annotations.count(attr::kReducerInfo)) { + auto map = op->annotations.Get(attr::kReducerInfo) + ->as>>(); + ICHECK(map) << "reducer_replication map is not defined"; + for (auto &&[var, rep] : map.value()) { + reducer_info_map_.Set( + var, ReducerInfo{rep.Get("op").value(), rep.Get("rep").value()}); + } + } + for (auto &&buffer : op->alloc_buffers) { + var_to_buffer_.Set(buffer->data, buffer); + } + auto result = IRMutatorWithAnalyzer::VisitStmt_(op).as().value(); + // After iterating over the body, set all layout_map to block + auto p_result = result.CopyOnWrite(); + auto layout_map = p_result->annotations.Get(attr::kLayoutMap) + ->as>() + .value_or(Map()); + for (auto &&[k, v] : new_layout_map_) + layout_map.Set(k, v); + if (!layout_map.empty()) + p_result->annotations.Set(attr::kLayoutMap, layout_map); + new_layout_map_.clear(); + return result; + } + + /** + * @brief Visit and possibly annotate a For node for reducer layout lowering. + * + * Visits a For node via the base mutator and, if the traversal is currently + * inside a reduction region (tracked by inside_reducer_range_) and this is + * the outermost loop of that region, annotates the loop with reducer + * information and derives per-buffer layout fragments for each reducer + * buffer. + * + * When annotating: + * - Sets the block-level `attr::kReducerInfo` annotation to the current + * inside_reducer_range_ map on the loop. + * - For each reducer buffer, reads the bound of `thread_var_` (requires the + * analyzer to have a const-int bound for it) and creates a Fragment: + * - If the reducer's replication type is ALL, creates a replication + * fragment across the thread extent. + * - If the replication type is NONE, builds a flattened index expression + * across buffer indices, reduces it modulo the thread extent, adds the + * thread minimum offset, and uses that as the fragment index. + * - Records the constructed Fragments into new_layout_map_ keyed by the + * buffer's data Var. + * + * Side effects: + * - May set `attr::kReducerInfo` on the For node's annotations. 
+ * - Updates `new_layout_map_`. + * - Reads and relies on `thread_var_`, `analyzer_->const_int_bound`, and + * `var_to_buffer_`. + * + * Preconditions and checks: + * - `thread_var_` must be defined and have a constant-int bound when + * annotating. + * - Each reducer Var in inside_reducer_range_ must map to an allocated Buffer + * in var_to_buffer_ (ICHECK enforced). + * + * @param op The original For node being visited. + * @return The (possibly) transformed For statement. + */ + Stmt VisitStmt_(const ForNode *op) final { + // only annotate the outermost loop + bool should_annotate = false; + if (!inside_reducer_range_.empty() && !already_annotated_ && + op->kind == ForKind::kParallel) { + should_annotate = true; + already_annotated_ = true; + } + + auto opt_result = IRMutatorWithAnalyzer::VisitStmt_(op).as(); + ICHECK(opt_result); + auto result = opt_result.value(); + + if (should_annotate) { + // we are leaving the current loop nest. later ones may annotate again + already_annotated_ = false; + + auto p_result = result.CopyOnWrite(); + p_result->annotations.Set(attr::kReducerInfo, inside_reducer_range_); + + // Iterate over local.reducer.* buffers, append to reducer_op_map_, set + // layout by adding layout_map annotations, and convert scope to + // local.fragment + for (auto &&[reducer_var, info] : inside_reducer_range_) { + // analyze thread index bound, need to be inside WS section + ICHECK(thread_var_.defined()); + ICHECK(analyzer_->const_int_bound.IsBound(thread_var_->var)); + auto const_int_bound = analyzer_->const_int_bound(thread_var_); + int thread_min = const_int_bound->min_value; + int thread_extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + + auto opt_buffer = var_to_buffer_.Get(reducer_var); + ICHECK(opt_buffer); + const auto &buffer = opt_buffer.value(); + Fragment f; + if (info->rep == ReducerRepType::ALL) { + f = Fragment::FullyReplicated(buffer->shape, thread_extent); + } else if (info->rep == ReducerRepType::NONE) { + PrimExpr flatten_idx = InputPlaceholder(0); + for (int i = 1; i < buffer->shape.size(); ++i) + flatten_idx = flatten_idx * buffer->shape[i] + InputPlaceholder(i); + f = Fragment(buffer->shape, {}, + indexmod(flatten_idx, thread_extent) + thread_min, 1, + std::nullopt); + } + new_layout_map_.Set(buffer->data, f); + } + } + return result; + } + + /** + * @brief Handle BufferStore statements during IR mutation. + * + * This override is the visit hook for BufferStoreNode. Currently it delegates + * to the base IRMutatorWithAnalyzer implementation. Intended as the place to + * perform reducer-specific viability checks for stores (e.g., validating + * operations against reducer metadata); such checks are TODO and are not yet + * implemented. + * + * @return Stmt The (possibly transformed) statement returned by the base + * mutator. + */ + Stmt VisitStmt_(const BufferStoreNode *op) final { + //! TODO: check store viable according to info->op + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + /** + * @brief Processes Call expressions to track reducer ranges and finalize + * reducer operations. + * + * Visits call nodes, detects T.fill calls that target reducer buffers and + * records their reducer metadata in inside_reducer_range_ until the matching + * T.finalize_reducer is seen. When a FinalizeReducerOp call is encountered, + * this method appends the reducer operation enum value to the call arguments + * and removes the corresponding entry from inside_reducer_range_. 
+ * + * Side effects: + * - Inserts and removes entries in inside_reducer_range_. + * - Mutates the FinalizeReducerOp call by pushing the reducer op enum as an + * extra argument. + * + * Failure modes: + * - ICHECK fails if a T.fill targets a reducer already recorded in + * inside_reducer_range_ (i.e., a prior T.fill without an intervening + * T.finalize_reducer). + * - ICHECK fails if T.finalize_reducer has no matching T.fill (no entry in + * inside_reducer_range_). + * + * @param op_ The CallNode being visited. + * @return PrimExpr The (possibly modified) call expression. + */ + PrimExpr VisitExpr_(const CallNode *op_) final { + auto op_ref = IRMutatorWithAnalyzer::VisitExpr_(op_).as().value(); + auto op = op_ref.CopyOnWrite(); + if (op->op.same_as(Fill::Get())) { + ICHECK(!op->args.empty()); + if (auto arg0_call = op->args[0].as()) { + // tl.region(...) — extract buffer var from its first arg + if (arg0_call.value()->op.same_as(RegionOp::Get())) { + ICHECK(!arg0_call.value()->args.empty()); + if (auto bl = arg0_call.value()->args[0].as()) { + Var var = bl->buffer->data; + if (reducer_info_map_.count(var)) { + ICHECK(inside_reducer_range_.count(var) == 0) + << "T.fill on reducer must be enclosed with a " + "T.finalize_reducer before next."; + inside_reducer_range_.Set(var, + reducer_info_map_.Get(var).value()); + } + } + } + // builtin.tvm_access_ptr(...) — existing path (legacy) + if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(arg0_call.value()->args.size() > 1); + if (auto var = arg0_call.value()->args[1].as(); + var && reducer_info_map_.count(var.value())) { + ICHECK(inside_reducer_range_.count(var.value()) == 0) + << "T.fill on reducer must be enclosed with a " + "T.finalize_reducer " + "before next."; + inside_reducer_range_.Set( + var.value(), reducer_info_map_.Get(var.value()).value()); + } + } + } else if (auto bl = op->args[0].as()) { + Var var = bl->buffer->data; + if (reducer_info_map_.count(var)) { + ICHECK(inside_reducer_range_.count(var) == 0) + << "T.fill on reducer must be enclosed with a T.finalize_reducer " + "before next."; + inside_reducer_range_.Set(var, reducer_info_map_.Get(var).value()); + } + } + } else if (op->op.same_as(FinalizeReducerOp::Get())) { + ICHECK(op->args.size() == 1); + Var var; + if (auto bl = op->args[0].as()) { + var = bl->buffer->data; + } else if (auto reg_call = op->args[0].as()) { + if (reg_call.value()->op.same_as(RegionOp::Get())) { + if (auto bl2 = reg_call.value()->args[0].as()) { + var = bl2->buffer->data; + } else { + LOG(FATAL) << "tl.region expects BufferLoad as first arg"; + } + } else { + var = GetVarFromAccessPtr(op->args[0]); + } + } else { + var = GetVarFromAccessPtr(op->args[0]); + } + ICHECK(inside_reducer_range_.count(var) == 1) + << "T.finalize_reducer must have a pairing T.fill ahead of it, " + "enclosing a reduction range."; + op->args.push_back((int)inside_reducer_range_.Get(var).value()->op); + inside_reducer_range_.erase(var); + } + return op_ref; + } + + /** + * @brief Construct a ReducerLayoutAnnotator with an arithmetic analyzer. + * + * Initializes the annotator's base IRMutatorWithAnalyzer with the provided + * arith::Analyzer, which the mutator uses to query symbolic bounds and + * simplify integer expressions during layout inference. + * + * @param analyzer Pointer to an arith::Analyzer used for symbolic analysis. 
+ */ + ReducerLayoutAnnotator(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {} + + IterVar thread_var_; + Map reducer_info_map_; + Map inside_reducer_range_; + bool already_annotated_ = false; + Map var_to_buffer_; + Map new_layout_map_; + +public: + /** + * @brief Apply reducer layout substitution to a PrimFunc. + * + * Runs the ReducerLayoutAnnotator over the function body to collect reducer + * metadata, insert layout mappings for reducer buffers, and lower + * local.reducer usage to local.fragment-compatible forms. Returns a new + * PrimFunc whose body is the transformed IR. + * + * @param f The PrimFunc to transform; passed by value and returned with an + * updated body. + * @return PrimFunc The transformed PrimFunc with reducer layouts and related + * rewrites applied. + */ + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + ReducerLayoutAnnotator substituter(&analyzer); + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = substituter.VisitStmt(f->body); + return f; + } +}; + +/** + * @brief Create a TVM transform pass that lowers local.reducer buffers to + * local.fragment layouts. + * + * This pass runs ReducerLayoutAnnotator::Substitute on a PrimFunc to collect + * reducer metadata, compute per-buffer layout fragments for reducer buffers, + * and annotate blocks with the resulting layout map. It is exposed as a + * PrimFunc-level pass named "tl.LayoutReducer". + * + * @return tvm::transform::Pass A prim-function pass that applies the + * layout-reduction substitution. + */ +tvm::transform::Pass LayoutReducer() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ReducerLayoutAnnotator::Substitute(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/layout_reducer.h b/tilelang/original/src/transform/layout_reducer.h new file mode 100644 index 0000000000000000000000000000000000000000..e46ade948ded6086d48d82eac9bfaad189acec4f --- /dev/null +++ b/tilelang/original/src/transform/layout_reducer.h @@ -0,0 +1,89 @@ +/*! + * \file layout_reducer.h + */ + +#ifndef TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_ +#define TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_ + +#include + +#include "../layout/layout.h" + +namespace tvm { +/** + * Types of reduction operations supported by TL transforms. + * + * SUM - arithmetic sum reduction. + * MAX - elementwise maximum reduction. + * MIN - elementwise minimum reduction. + */ + +/** + * Representation semantics for a reducer. + * + * ALL - reducer collapses all elements along the reduced axes. + * NONE - reducer does not collapse (used to represent a placeholder/no-op). + */ + +/** + * Holds metadata describing a reducer used in layout transforms. + * + * Contains the reduction operation (`op`) and its representation semantics + * (`rep`). + */ + +/** + * Construct a ReducerInfoNode from textual identifiers. + * + * @param op_str String identifier for the reduction operation (e.g., "sum", + * "max", "min"). + * @param rep_str String identifier for the representation semantics (e.g., + * "all", "none"). + */ + +/** + * Handle type for ReducerInfoNode (ObjectRef wrapper). + * + * Constructed from string identifiers for operation and representation. 
+ * + * @param op_str String identifier for the reduction operation (e.g., "sum", + * "max", "min"). + * @param rep_str String identifier for the representation semantics (e.g., + * "all", "none"). + */ + +/** + * Attribute key used to attach ReducerInfo to IR nodes or other attribute maps. + */ +namespace tl { + +enum class ReducerOpType { SUM, MAX, MIN }; +enum class ReducerRepType { ALL, NONE }; + +struct ReducerInfoNode : Object { + ReducerOpType op; + ReducerRepType rep; + + ReducerInfoNode() = default; + ReducerInfoNode(const String &op_str, const String &rep_str); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object); +}; + +struct ReducerInfo : ObjectRef { +public: + TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) { + data_ = tvm::ffi::make_object(op_str, rep_str); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef, + ReducerInfoNode); +}; + +namespace attr { +constexpr const char *kReducerInfo = "reducer_info"; +} + +} // namespace tl +} // namespace tvm + +#endif diff --git a/tilelang/original/src/transform/legalize_negative_index.cc b/tilelang/original/src/transform/legalize_negative_index.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0df555ef5e66e84d227667200982efa9c6e0da6 --- /dev/null +++ b/tilelang/original/src/transform/legalize_negative_index.cc @@ -0,0 +1,239 @@ +/*! + * \file legalize_negative_index.cc + * \brief Legalize negative indices in buffer load/store expressions. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRVisitorWithAnalyzer; + +enum class IndexSignState { kNonNegative, kNegative, kUnknown }; + +using BufferAccessVariant = + std::variant; +using LoadStore2StateMap = + std::unordered_map>; + +class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { +public: + explicit NegativeIndexAnalyzer(LoadStore2StateMap *result) + : result_(result) {} + +private: + std::vector ProcessIdx(const ffi::Array &indices, + ffi::String buffer_name) { + std::vector states; + states.reserve(indices.size()); + + for (size_t i = 0; i < indices.size(); ++i) { + PrimExpr simplified = analyzer_.Simplify(indices[i]); + IndexSignState state = IndexSignState::kUnknown; + + // Handle scalar indices with the standard analyzer + if (simplified.dtype().lanes() == 1) { + if (analyzer_.CanProve(simplified >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(simplified < 0)) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; + } + // Vector indices: try to reason about non-negativity/negativity + // Common patterns are Ramp(base, stride, lanes) and Broadcast(value, + // lanes). 
+ else if (const auto *ramp = simplified.as()) { + // Compute a safe lower/upper bound for the vector lanes + // lower_bound = base_min + min(0, stride_min) * (lanes - 1) + // upper_bound = base_max + max(0, stride_max) * (lanes - 1) + auto base_bound = analyzer_.const_int_bound(ramp->base); + auto stride_bound = analyzer_.const_int_bound(ramp->stride); + int lanes = *as_const_int(ramp->lanes); + + int64_t base_min = base_bound->min_value; + int64_t base_max = base_bound->max_value; + int64_t s_min = stride_bound->min_value; + int64_t s_max = stride_bound->max_value; + + // Guard against overflow is not strictly necessary here because + // bounds may be +/-inf represented by sentinel values. + int64_t lower = base_min; + if (s_min < 0) + lower += s_min * (lanes - 1); + int64_t upper = base_max; + if (s_max > 0) + upper += s_max * (lanes - 1); + + if (lower >= 0) + state = IndexSignState::kNonNegative; + else if (upper < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; + } else if (const auto *broadcast = simplified.as()) { + auto v = analyzer_.Simplify(broadcast->value); + if (analyzer_.CanProve(v >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(v < 0)) + state = IndexSignState::kNegative; + else { + // Try const bound if proof unavailable + auto vb = analyzer_.const_int_bound(v); + if (vb->min_value >= 0) + state = IndexSignState::kNonNegative; + else if (vb->max_value < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; + } + } + states.push_back(state); + } + + return std::move(states); + } + + bool NeedRecord(const std::vector &states) { + return std::any_of(states.begin(), states.end(), + [](const IndexSignState &state) { + return state == IndexSignState::kUnknown || + state == IndexSignState::kNegative; + }); + } + + void VisitExpr_(const BufferLoadNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); + + if (NeedRecord(states)) + (*result_)[op] = std::move(states); + + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); + + if (NeedRecord(states)) + (*result_)[op] = std::move(states); + + IRVisitorWithAnalyzer::VisitStmt_(op); + } + +private: + LoadStore2StateMap *result_; +}; + +class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { +public: + static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) { + arith::Analyzer analyzer; + NegativeIndexRewriter rewriter(&analyzer, states); + PrimFuncNode *func_node = func.CopyOnWrite(); + func_node->body = rewriter.VisitStmt(func_node->body); + return func; + } + +private: + NegativeIndexRewriter(arith::Analyzer *analyzer, + const LoadStore2StateMap &states) + : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} + + ffi::Array UpdateIdx(const ffi::Array &indices, + const ffi::Array &buffer_shape, + const std::vector &state_vec) { + ICHECK_EQ(state_vec.size(), indices.size()) + << "State vector size mismatch for buffer load/store indices (" + << indices << ")"; + ffi::Array new_indices = indices; + for (size_t i = 0; i < indices.size(); ++i) { + if (state_vec[i] 
!= IndexSignState::kNegative) + continue; + new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i])); + } + return new_indices; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = + Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); + + auto it = states_.find(op); + if (it == states_.end()) + return load; + + auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second); + return BufferLoad(load->buffer, indices, load->predicate); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = + Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + + auto it = states_.find(op); + if (it == states_.end()) + return store; + + auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second); + return BufferStore(store->buffer, store->value, indices, store->predicate); + } + +private: + const LoadStore2StateMap &states_; +}; + +PrimFunc LegalizeNegativeIndex(PrimFunc func) { + if (!func->body.defined()) { + return func; + } + + LoadStore2StateMap states; + NegativeIndexAnalyzer analyzer(&states); + analyzer(func->body); + if (states.empty()) { + return func; + } + + return NegativeIndexRewriter::Apply(std::move(func), states); +} + +tvm::transform::Pass LegalizeNegativeIndexPass() { + using namespace tir::transform; + auto pass_func = [](PrimFunc f, const IRModule &, PassContext) { + return LegalizeNegativeIndex(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex", + LegalizeNegativeIndexPass); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/legalize_safe_memory_access.cc b/tilelang/original/src/transform/legalize_safe_memory_access.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a9da919c7c725be1e60ff20d8e528ef5faca77d --- /dev/null +++ b/tilelang/original/src/transform/legalize_safe_memory_access.cc @@ -0,0 +1,316 @@ +/*! + * \file legalize_safe_memory_access.cc + * \brief legalize safe memory access + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "../op/parallel.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "loop_partition.h" +#include "loop_vectorize.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; + +// GlobalMemChecker for a BufferLoad/BufferStore node: +// 1. Identify BufferLoad and BufferStore nodes. +// 2. Check if the buffer is in global scope. +// 3. For each index, compare against the buffer's shape. +// If the index might exceed the shape (upper bound too large), +// log a warning or handle accordingly. +struct GlobalMemChecker : public StmtExprVisitor { + + GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds) + : analyzer_(analyzer), + recursively_collect_conds_(recursively_collect_conds) {} + void VisitExpr_(const BufferLoadNode *op) final { + // Check if the buffer is in global scope + // This is because we are writing TilePrograms, where out of bounds + // accesses only happen in the global buffer. 
+ if (IsGlobalBuffer(op->buffer)) { + CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true); + } + if (recursively_collect_conds_) { + StmtExprVisitor::VisitExpr_(op); + } + } + + void VisitStmt_(const BufferStoreNode *op) final { + // Check if the buffer is in global scope + if (IsGlobalBuffer(op->buffer)) { + CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false); + } + if (recursively_collect_conds_) { + StmtExprVisitor::VisitStmt_(op); + } + } + + // Helper function to determine if a buffer is global + bool IsGlobalBuffer(const Buffer &buffer) { + // The storage scope is often encoded in the buffer->data var name or + // associated attributes. In typical TVM IR, global buffers have scope + // "global". Here we assume a helper function GetPtrStorageScope is + // available. If not, you might need to parse buffer->data->name_hint or + // associated attributes. + String scope = buffer.scope(); + return scope == "global"; + } + + // Check each index against the buffer shape dimensions + void CheckBufferIndices(const Buffer &buffer, const Array &indices, + bool is_load) { + // Ensure indices count matches buffer dimension + if (indices.size() != buffer->shape.size()) { + LOG(WARNING) << "Buffer access dimension mismatch: indices size (" + << indices.size() << ") vs. shape size (" + << buffer->shape.size() << ")"; + return; + } + + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr index = indices[i]; + PrimExpr shape_dim = buffer->shape[i]; + + bool is_index_constant = true; + PostOrderVisit(index, [&](const ObjectRef &obj) { + if (const VarNode *v = obj.as()) { + is_index_constant = false; + } + if (const BufferLoadNode *v = obj.as()) { + is_index_constant = false; + } + }); + if (is_index_constant) { + // If index is a constant, we can skip the check + continue; + } + + // We want to check if index < shape_dim can be proven. + // If analyzer->CanProve(index < shape_dim) returns false, + // it means we cannot prove the access is within bounds. + PrimExpr upper_bound_cond = index < shape_dim; + if (!analyzer_->CanProve(upper_bound_cond, + arith::ProofStrength::kSymbolicBound)) { + _conditions.push_back(upper_bound_cond); + } + // Check if index >= 0 can be proven. 
+ PrimExpr lower_bound_cond = index >= 0; + if (!analyzer_->CanProve(lower_bound_cond, + arith::ProofStrength::kSymbolicBound)) { + _conditions.push_back(lower_bound_cond); + } + } + } + + Array GetConditions() { return _conditions; } + +private: + Array _conditions; + arith::Analyzer *analyzer_; + bool recursively_collect_conds_; +}; + +class SafeMemorysRewriter : public IRMutatorWithAnalyzer { +public: + // Static method to substitute and transform the given PrimFunc + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + // Create an instance of the legalizer with the analyzer + SafeMemorysRewriter substituter(&analyzer); + // Get a mutable copy of the function node + PrimFuncNode *fptr = f.CopyOnWrite(); + for (const auto &[_, buffer] : f->buffer_map) { + substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + // Apply the legalizer to the function body + fptr->body = substituter.VisitStmt(f->body); + return f; + } + +private: + // Constructor initializing the base class with the analyzer + SafeMemorysRewriter(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} + // Constructor initializing the base class with the analyzer + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + + // For Load/Store, we only check the current node, not its children. + // Since rewriter will recursively visit children. + GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + checker(load); + Array conditions = checker.GetConditions(); + + if (conditions.empty()) { + return load; + } + + // For loading, we can always use safe value if the access is out of + // bounds + PrimExpr value = load; + for (auto cond : conditions) { + ICHECK(cond.dtype() == DataType::Bool(1)) + << "condition is not a boolean: " << cond; + value = if_then_else(cond, value, GetSafeValue(load->buffer)); + } + return value; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + // Check if the buffer is in global scope + auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + + GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + checker(store); + Array conditions = checker.GetConditions(); + + // Skip boundary check if the store value is an IfThenElse + if (const IfThenElseNode *if_node = store->value.as()) { + if (!conditions.empty()) { + LOG(WARNING) + << "Skipping boundary check for store with IfThenElse value: " + << store->value + << "\nAs manual boundary check detected, potential out-of-bounds " + "access may occur." + << "\nAuto detect boundaries are " << conditions; + return store; + } + return store; + } + + if (conditions.empty()) { + return store; + } + + // If a store is out of bounds, we skip the corresponding stmt directly. + Stmt store_with_conditions = store; + for (auto cond : conditions) { + store_with_conditions = IfThenElse(cond, store_with_conditions); + } + return store_with_conditions; + } + + // Recursively check Load/Store in the call arguments. + // For example + // T.call_extern("handle", "atomicAddx2", T.address_of(C), + // T.address_of(C_shared)) + + // NOTE(chaofan): This is currently not the most rigorous solution. + // The check here is primarily intended to handle extern functions like + // atomicAdd, which may involve memory access. Due to their special nature, + // the BufferLoad in their parameters might be used for boundary checks of the + // current statement. 
The current solution adopts a simplified approach: + // directly applying the boundary constraints of all parameters to the + // statement. While not entirely precise, it addresses most common scenarios. + Stmt VisitStmt_(const EvaluateNode *op) final { + auto evaluate = Downcast(op); + + if (const CallNode *call_op = op->value.as()) { + auto call = Downcast(op->value); + if (call->op == builtin::call_extern()) { + // For CallExtern, we recursively collect conditions from all children. + // Since we cannot rewrite any BufferLoad in its children (Rewrite will + // cause potential Nullptr exception). + GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true); + checker(call); + Array conditions = checker.GetConditions(); + + if (conditions.empty()) { + return evaluate; + } + + Stmt evaluate_with_conditions = evaluate; + for (auto cond : conditions) { + evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions); + } + return evaluate_with_conditions; + } + } + + return evaluate; + } + + Stmt VisitStmt_(const BlockNode *op) final { + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + if (op->annotations.count(attr::kSafeValueMap)) { + auto map = op->annotations.Get(attr::kSafeValueMap) + ->as>() + .value(); + for (const auto &[var, safe_value] : map) { + ICHECK(buffer_data_to_buffer_.count(var)) + << "buffer " << var << " is not found in the block " + << buffer_data_to_buffer_; + auto buffer = buffer_data_to_buffer_[var]; + annotated_safe_value_map_.Set(buffer, safe_value); + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + bool IsLocalBuffer(const Buffer &buffer) { + String scope = buffer.scope(); + return scope == "local" || scope == "local.fragment" || + scope == "local.var"; + } + + bool isSharedBuffer(const Buffer &buffer) { + String scope = buffer.scope(); + return scope == "shared" || scope == "shared.dyn"; + } + + bool IsGlobalBuffer(const Buffer &buffer) { + String scope = buffer.scope(); + return scope == "global"; + } + // Get the safe value of the buffer + PrimExpr GetSafeValue(const Buffer &buffer) { + if (annotated_safe_value_map_.count(buffer)) { + return annotated_safe_value_map_[buffer]; + } + return make_zero(buffer->dtype); + } + + Map buffer_data_to_buffer_; + Map annotated_safe_value_map_; +}; + +// Create a pass that legalizes vectorized loops in the IRModule +tvm::transform::Pass LegalizeSafeMemoryAccess() { + using namespace tir::transform; + // Define the transformation function to be applied + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_safe_memory_legalize = + ctx->GetConfig(kDisableSafeMemoryLegalize, Bool(false)).value(); + if (disable_safe_memory_legalize) { + return f; + } + return SafeMemorysRewriter::Substitute(std::move(f)); + }; + // Create and return a PrimFunc pass with the transformation function + return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {}); +} + +// Register the pass globally so it can be used in the compilation pipeline +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess", + LegalizeSafeMemoryAccess); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/legalize_vectorized_loop.cc b/tilelang/original/src/transform/legalize_vectorized_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..4fd4ab91f6baf1fceb67286f9c66c48f2803c424 --- /dev/null +++ 
b/tilelang/original/src/transform/legalize_vectorized_loop.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file layout_inference.cc + * \brief infer the fragment/shared memory layout + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/parallel.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "loop_partition.h" +#include "loop_vectorize.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; + +// Class to legalize vectorized loops by transforming them appropriately +class LoopVectorizedLegalizer : IRMutatorWithAnalyzer { +public: + // Static method to substitute and transform the given PrimFunc + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + // Create an instance of the legalizer with the analyzer + LoopVectorizedLegalizer substituter(&analyzer); + // Get a mutable copy of the function node + PrimFuncNode *fptr = f.CopyOnWrite(); + // Apply the legalizer to the function body + fptr->body = substituter.VisitStmt(f->body); + return f; + } + +private: + // Constructor initializing the base class with the analyzer + LoopVectorizedLegalizer(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} + + // Override the VisitStmt_ method to handle ForNode (loop statements) + Stmt VisitStmt_(const ForNode *op) final { + // Visit and potentially modify the loop node + For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + // If the loop is not vectorized, proceed with the default behavior + if (for_node->kind != ForKind::kVectorized) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + // Change the loop kind from vectorized to serial + for_node.CopyOnWrite()->kind = ForKind::kSerial; + // Apply vectorization transformation to the loop + return VectorizeLoop(for_node, analyzer_); + } +}; + +// Create a pass that legalizes vectorized loops in the IRModule +tvm::transform::Pass LegalizeVectorizedLoop() { + using namespace tir::transform; + // Define the transformation function to be applied + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return LoopVectorizedLegalizer::Substitute(std::move(f)); + }; + // Create and return a PrimFunc pass with the transformation function + return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeVectorizedLoop", {}); +} + +// Register the pass globally so it can be used in the compilation pipeline +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop", + LegalizeVectorizedLoop); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/loop_partition.cc 
b/tilelang/original/src/transform/loop_partition.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4236c6dbcf5b70933367e0afb5f6d7dd535d788 --- /dev/null +++ b/tilelang/original/src/transform/loop_partition.cc @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file loop_partition.cc + * \brief Partition parallel loops onto threads + */ + +#include "loop_partition.h" + +#include + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class BufferIndiceSimplify : public StmtExprMutator { +public: + BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {} + +private: + PrimExpr VisitExpr_(const BufferLoadNode *node) final { + auto visited = StmtExprMutator::VisitExpr_(node); + auto n = Downcast(visited); + auto nptr = n.CopyOnWrite(); + nptr->indices = nptr->indices.Map( + [&](const auto &e) { return analyzer_->Simplify(e); }); + return n; + } + Stmt VisitStmt_(const BufferStoreNode *node) final { + auto visited = StmtExprMutator::VisitStmt_(node); + auto n = Downcast(visited); + auto nptr = n.CopyOnWrite(); + nptr->indices = nptr->indices.Map( + [&](const auto &e) { return analyzer_->Simplify(e); }); + return n; + } + arith::Analyzer *analyzer_; +}; + +// Rewrite the parallel loop into a common loop, which is mapped to threads +For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, + const Fragment &loop_layout) { + ICHECK(loop_layout.defined()); + ICHECK(thread_var.defined()); + int old_loop_depth = loop_layout->InputDim(); + int new_loop_depth = loop_layout->OutputDim(); + // Create the new loop iter var + Array vars; + for (int i = 0; i < new_loop_depth; i++) { + Var var = Var(std::string{char('i' + i)}); + analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype), + loop_layout->OutputShape()[i])); + vars.push_back(var); + } + vars.push_back(thread_var); + // create the substitute map, and the loop body + Map vmap; + Stmt body = std::move(op); + Array loop_mins; + Array loop_extents; + auto inverse_info = loop_layout->InverseWithLevel(); + auto inv_loop = inverse_info.first; + // Must check the guard if the layout can not be proved as bijective + bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective; + auto indices = inv_loop->Forward(Array(vars.begin(), vars.end())); + // Normalize thread var once so we can reuse the same substitution later. 
+ Map thread_offset_map; + bool has_thread_offset = false; + if (loop_layout->ThreadRange().defined()) { + auto range = loop_layout->ThreadRange(); + thread_offset_map.Set(thread_var, thread_var - range->min); + has_thread_offset = true; + } + for (int i = 0; i < old_loop_depth; i++) { + const ForNode *loop = body.as(); + ICHECK(loop != nullptr) + << "No extra statements are allowed between nested parallel loops."; + vmap.Set(loop->loop_var, indices[i]); + loop_mins.push_back(loop->min); + loop_extents.push_back(loop->extent); + body = loop->body; + } + // substitute and re-construct the serial loop + body = Substitute(body, vmap); + // The guard executes the recovered loop body only if each inverse-mapped iterator + // falls back into the original For ranges. We first check every axis from the + // old loop nest (old_loop_depth) and then the extra index produced by inverse + // layouts that carry a replicate/thread component (`inv_output_shape`). Both + // must stay within bounds to ensure correctness. Example: layout([i, j]) = + // floor((i * 16 + j) / 32) may generate extra points when the new loop + // enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j) + // or replicate index fall outside their original extents. This protects + // non-surjective loop_layout mappings that otherwise over-cover the parallel + // space. + PrimExpr guard = const_true(); + + if (need_guard) { + for (int i = 0; i < old_loop_depth; i++) { + PrimExpr index = indices[i]; + if (has_thread_offset) { + index = Substitute(index, thread_offset_map); + } + PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]); + PrimExpr upper_bound = + analyzer->Simplify(index < loop_mins[i] + loop_extents[i]); + guard = And(guard, And(lower_bound, upper_bound)); + } + auto inv_output_shape = inv_loop->OutputShape(); + if (inv_output_shape.size() > static_cast(old_loop_depth)) { + PrimExpr replicate_index = indices[old_loop_depth]; + if (has_thread_offset) { + replicate_index = Substitute(replicate_index, thread_offset_map); + } + PrimExpr replicate_extent = inv_output_shape[old_loop_depth]; + PrimExpr lower_bound = analyzer->Simplify( + replicate_index >= make_zero(replicate_index.dtype())); + PrimExpr upper_bound = + analyzer->Simplify(replicate_index < replicate_extent); + guard = And(guard, And(lower_bound, upper_bound)); + } + PrimExpr simplified_guard = analyzer->Simplify(guard); + if (!analyzer->CanProve(simplified_guard)) { + body = IfThenElse(simplified_guard, body, Stmt()); + } + } + + for (int i = new_loop_depth - 1; i >= 0; i--) { + body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i], + ForKind::kSerial, body); + analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i])); + } + + body = BufferIndiceSimplify(analyzer)(body); + + if (has_thread_offset) { + body = Substitute(body, thread_offset_map); + } + + auto for_node = LoopPragmaUnroll(Downcast(body)); + return for_node; +} + +class LoopPramaUnroller : public StmtExprMutator { +public: + LoopPramaUnroller() = default; + +private: + Stmt VisitStmt_(const ForNode *node) final { + if (node->kind == ForKind::kSerial) { + auto analyzer = std::make_shared(); + if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) { + return StmtExprMutator::VisitStmt_(node); + } + For new_for = tvm::ffi::GetRef(node); +
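+ // CopyOnWrite below gives this mutator a private copy of the loop node before the unroll pragma is attached, so the statement tree visited above is not modified in place.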
auto for_ptr = new_for.CopyOnWrite(); + for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false)); + for_ptr->kind = ForKind::kUnrolled; + return new_for; + } + return StmtExprMutator::VisitStmt_(node); + } +}; + +class LoopPartitioner : public StmtExprVisitor { +public: + LoopPartitioner() = default; + + Fragment Partition(const For &op, int num_thread, int vectorize_size) { + this->VisitStmt(op); + DataType dtype = DataType::Int(32); + if (!loop_vars_.empty()) { + dtype = loop_vars_.back()->var.dtype(); + } + PrimExpr flattened = make_const(dtype, 0); + PrimExpr vector_extent = make_const(dtype, vectorize_size); + PrimExpr thread_extent_const = make_const(dtype, num_thread); + for (size_t i = 0; i < loop_vars_.size(); i++) { + PrimExpr extent = loop_vars_[i]->dom->extent; + flattened = flattened * extent + loop_vars_[i]->var; + } + PrimExpr access_idx = FloorDiv(flattened, vector_extent); + PrimExpr thd = FloorMod(access_idx, thread_extent_const); + PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent + + FloorMod(flattened, vector_extent); + + auto fragment = Fragment(loop_vars_, {idx}, {thd}, {}); + if (has_fragment_) { + // for fragment buffer, we don't need to replicate the loop layout + auto thread_extent = *as_const_int(fragment->ThreadExtent()); + auto num_thread_fragment = num_thread / thread_extent; + fragment = fragment->Replicate(num_thread_fragment); + } + return fragment; + } + +private: + void VisitExpr_(const BufferLoadNode *op) final { + if (op->buffer.scope() == "local.fragment") { + has_fragment_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.scope() == "local.fragment") { + has_fragment_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const ForNode *node) final { + if (node->kind == ForKind::kParallel) { + body_ = node->body; + loop_vars_.push_back( + IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var, + IterVarType::kDataPar)); + } + StmtExprVisitor::VisitStmt_(node); + } + + Stmt body_; + PrimExpr flattened = 0; + bool has_fragment_ = false; + Array loop_vars_; +}; + +Fragment PlanLoopPartition(const For &op, size_t num_thread, + int vectorize_size) { + LoopPartitioner partitioner; + return partitioner.Partition(op, num_thread, vectorize_size); +} + +Fragment PlanLoopPartition(const For &op, int vectorize_size, + const Range &thread_range) { + size_t num_thread = *as_const_int(thread_range->extent); + LoopPartitioner partitioner; + Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size); + return fragment->BindThreadRange(thread_range); +} + +For LoopPragmaUnroll(For stmt) { + LoopPramaUnroller unroller; + For unrolled = Downcast(unroller(std::move(stmt))); + return unrolled; +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/loop_partition.h b/tilelang/original/src/transform/loop_partition.h new file mode 100644 index 0000000000000000000000000000000000000000..1103e7515b400873c037bd75b706498bfe0781f9 --- /dev/null +++ b/tilelang/original/src/transform/loop_partition.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file loop_partition.h + * \brief Partition parallel loops onto threads + */ + +#ifndef TVM_TL_LOOP_PARTITION_H_ +#define TVM_TL_LOOP_PARTITION_H_ + +#include + +#include "../layout/layout.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, + const Fragment &loop_layout); + +Fragment PlanLoopPartition(const For &op, size_t num_thread, + int vectorize_size); + +Fragment PlanLoopPartition(const For &op, int vectorize_size, + const Range &thread_range); + +For LoopPragmaUnroll(For stmt); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_LOOP_PARTITION_H_ diff --git a/tilelang/original/src/transform/loop_vectorize.cc b/tilelang/original/src/transform/loop_vectorize.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a446731f173854eaa1d661a7621699be883ef81 --- /dev/null +++ b/tilelang/original/src/transform/loop_vectorize.cc @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file loop_vectorize.cc + * \brief A tool to automatically vectorize a for loop + */ + +#include "loop_vectorize.h" +#include "../op/builtin.h" +#include "../target/utils.h" +#include "arith/int_operator.h" +#include "arith/ir_visitor_with_analyzer.h" +#include "common/loop_vectorization_utils.h" +#include "tvm/tir/analysis.h" +#include "tvm/tir/var.h" +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +struct VectorizePlanResult { + int vector_size; + bool dynamic; + PrimExpr condition; +}; + +class VectorizeFindGlobalAccess : public StmtExprVisitor { +public: + VectorizeFindGlobalAccess() = default; + + bool HasGlobalAccess(const Stmt &stmt) { + this->operator()(stmt); + return has_global_access_; + } + +private: + bool has_global_access_ = false; + + void VisitStmt_(const BufferStoreNode *node) final { + if (node->buffer.scope() == "global") + has_global_access_ = true; + return StmtExprVisitor::VisitStmt_(node); + } + + void VisitExpr_(const BufferLoadNode *node) final { + if (node->buffer.scope() == "global") + has_global_access_ = true; + return StmtExprVisitor::VisitExpr_(node); + } +}; + +class VectorizePlanner : public arith::IRMutatorWithAnalyzer { +public: + explicit VectorizePlanner(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} + + int Plan(const For &node) { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + Optional opt_disable_vectorize_256 = + ctxt->GetConfig(kDisableVectorize256, Optional()); + bool disable_vectorize_256 = + opt_disable_vectorize_256.value_or(Bool(false)); + if (tvm::tl::TargetIsSm100(Target::Current(false)) && + !disable_vectorize_256 && + VectorizeFindGlobalAccess().HasGlobalAccess(node)) { + vector_load_bits_max_ = vector_size_ = 256; + } else { + vector_load_bits_max_ = vector_size_ = 128; + } + this->operator()(node); + return vector_size_; + } + +private: + Stmt VisitStmt_(const ForNode *node) final { + inner_for_ = node; + bool contains_nested_for = false; + // Must analysis vectorization on the innermost loop + PostOrderVisit(Downcast(node->body), [&](const ObjectRef &obj) { + if (obj.as()) { + contains_nested_for = true; + } + }); + + if (!contains_nested_for) { + auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent)); + // Here I disable dynamic shape completely, + // In order to do it, the Planner should accept an analyzer with + // arithmetic info outside to prove the dividiblity of vector size + if (!extent_ptr) { + vector_size_ = 1; + return ffi::GetRef(node); + } + vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); + } + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *node) final { + if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || + node->buffer.scope() == "shared.dyn") + has_nonlocal_memory_access_ = true; + if (node->buffer->shape.size() == 1) { + // TODO(lei): This should be improved as + // constant buffer that tl hack to use as local register. 
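+ // In other words, a buffer declared with shape [1] behaves as a scalar register rather than a real memory access, so it is not allowed to constrain the vector width.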
+ auto boundary_check = node->buffer->shape[0].as(); + if (boundary_check && boundary_check->value == 1) { + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } + } + UpdateVectorSize(node->indices, node->buffer); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } + + Stmt VisitStmt_(const BufferStoreNode *node) final { + if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || + node->buffer.scope() == "shared.dyn") + has_nonlocal_memory_access_ = true; + UpdateVectorSize(node->indices, node->buffer); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); + } + + Stmt VisitStmt_(const IfThenElseNode *node) final { + CheckConditionVectorized(node->condition); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); + } + + PrimExpr VisitExpr_(const CallNode *node) final { + if (node->op == builtin::if_then_else()) { + CheckConditionVectorized(node->args[0]); + } else if (node->op == builtin::call_extern()) { + // do not vectorize extern calls + vector_size_ = 1; + } else if (node->op.same_as(tl::rng_rand()) || + node->op.same_as(tl::rng_init())) { + // do not vectorize random operation + vector_size_ = 1; + } + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } + + void CheckConditionVectorized(const PrimExpr &cond) { + // TODO: perform some checks here + } + + PrimExpr VisitExpr_(const CastNode *node) final { + vector_size_ = arith::ZeroAwareGCD( + vector_load_bits_max_ / node->dtype.bits(), vector_size_); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } + + void UpdateVectorSize(const Array indices, const Buffer &buffer) { + if (!inner_for_) + return; + // 1. Compute raw element offset + auto strides = buffer->strides; + if (buffer->strides.empty()) { + PrimExpr stride = 1; + for (int i = indices.size() - 1; i >= 0; --i) { + strides.push_back(stride); + stride = stride * buffer->shape[i]; + } + strides = Array{strides.rbegin(), strides.rend()}; + } + PrimExpr elem_offset = 0; + for (int i = 0; i < indices.size(); ++i) { + elem_offset += indices[i] * strides[i]; + } + // 2. If element offset is independent with loop_var, ignore it + if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) { + return; + } + // 3. Check if current vector_size_ works with invariant boundary check + if (!IsExprInvariantInVectorBoundary(elem_offset, inner_for_->loop_var, + vector_size_, analyzer_)) { + // If not, tight vectorize bound with buffer dtype constraint + vector_size_ = arith::ZeroAwareGCD( + vector_size_, vector_load_bits_max_ / + (buffer->dtype.bits() * buffer->dtype.lanes())); + } + // 4. 
Try to vectorize buffer load + while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, + inner_for_->extent, vector_size_, analyzer_)) { + vector_size_ /= 2; + } + } + + int vector_load_bits_max_; + + const ForNode *inner_for_{}; + bool has_nonlocal_memory_access_ = false; + int vector_size_ = 128; +}; + +class VectorizeRewriter : public StmtExprMutator { +public: + VectorizeRewriter(int vector_size) : vector_size_(vector_size) {} + +private: + Stmt VisitStmt_(const ForNode *node) final { + inner_for_ = node; + auto ret = StmtExprMutator::VisitStmt_(node); + if (inner_for_ == node) { // rewrite the innermost loop + For fnode = ret.as().value(); + auto old_var = fnode->loop_var; + auto extent_ptr = as_const_int(fnode->extent); + ICHECK(extent_ptr) << fnode->extent; + int extent = *extent_ptr; + ICHECK(extent % vector_size_ == 0) + << "extent: " << extent << " vector_size_: " << vector_size_; + ICHECK(is_zero(fnode->min)); + if (extent == vector_size_) { + fnode.CopyOnWrite()->kind = ForKind::kVectorized; + return fnode; + } else { + Var inner_var = Var("vec"); + Var outer_var = Var(old_var->name_hint); + Map vmap; + vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var); + Stmt body = Substitute(fnode->body, vmap); + body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); + body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, + fnode->thread_binding, fnode->annotations, fnode->step, + fnode->span); + return body; + } + } else { + return ret; + } + } + + const ForNode *inner_for_{}; + const int vector_size_; +}; + +int GetVectorizeSize(const For &loop) { + arith::Analyzer analyzer; + return VectorizePlanner(&analyzer).Plan(loop); +} + +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { + return VectorizePlanner(analyzer).Plan(loop); +} + +bool CanProveIndependent(const PrimExpr &expr, Var var, + arith::Analyzer *analyzer) { + // 1. if var doesn't exist, it is independent + bool used_var = UsesVar(expr, [&](const VarNode *v) { + return tvm::ffi::GetRef(v).same_as(var); + }); + if (!used_var) { + return true; + } + // 2. 
if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v + Var var_1("_t", var.dtype()); + auto expr_1 = Substitute(expr, {{var, var_1}}); + if (analyzer->CanProveEqual(expr, expr_1)) { + return true; + } + return false; +} + +bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var, + int target_vectorized_size, + arith::Analyzer *analyzer) { + // Check if expr is invariant within vector boundaries + // We're trying to prove the access expression A[f(var)] depends only on + // floor(var/vecsize), not on var%vecsize + // Mathematically: + // \forall var, f(floor(var/vecsize)*vecsize + var%vecsize) == + // f(floor(var/vecsize)*vecsize + 0) + // Example: for i in T.vectorized(8): + // A[i] = B[i] * C[i//4] + // if vecsize=4, f(i)=i//4 depends only on i//4 + // Therefore A[i] = B[i] * C[i//4] can be vectorized with vecsize=4 + PrimExpr var_aligned = + floordiv(var, target_vectorized_size) * target_vectorized_size; + PrimExpr expr_aligned = Substitute(expr, {{var, var_aligned}}); + if (analyzer->CanProveEqual(expr, expr_aligned)) { + return true; + } + return false; +} + +bool IndiceCanVectorize(const PrimExpr &expr, Var var, + const PrimExpr &iter_var_size, + int target_vectorized_size, arith::Analyzer *analyzer) { + ICHECK(target_vectorized_size >= 1); + if (target_vectorized_size == 1) + return true; + + // Extent must be divisible + PrimExpr target_size_for_iter = + make_const(iter_var_size.dtype(), target_vectorized_size); + PrimExpr target_size_for_expr = + make_const(expr.dtype(), target_vectorized_size); + PrimExpr target_size_for_var = + make_const(var.dtype(), target_vectorized_size); + PrimExpr zero = make_const(var.dtype(), 0); + + if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), + 0)) + return false; + + if (IsExprInvariantInVectorBoundary(expr, var, target_vectorized_size, + analyzer)) { + return true; + } + + auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}})); + // The base offset must be divisible + if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr), + zero)) { + return false; + } + + // Bind thread range + Var v0("v0", var.dtype()), v1("v1", var.dtype()); + analyzer->Bind(v0, Range(zero, target_size_for_var)); + analyzer->Bind(v1, Range(zero, analyzer->Simplify(FloorDiv( + iter_var_size, target_size_for_iter)))); + PrimExpr expr_transformed = analyzer->Simplify( + Substitute(expr, {{var, v0 + v1 * target_size_for_var}})); + Vectorizer vectorizer(v0, target_size_for_var); + PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); + + // This simplify is necessary for thread region specified + // optimizations. 
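+ // Illustration of the check below: with var rewritten as v0 + v1 * vec, a contiguous offset such as elem_offset == var simplifies to a stride-1 Ramp over vec lanes and is accepted; an offset that does not involve var stays scalar (lanes == 1) and is accepted as a broadcast; anything else, e.g. the stride-2 ramp produced by elem_offset == 2 * var, is rejected.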
+ expr_vectorized = analyzer->Simplify(expr_vectorized); + auto ramp_node = expr_vectorized.as(); + if (!ramp_node) { + // Broadcast value + if (expr_vectorized.dtype().lanes() == 1) + return true; + else + return false; + } else { + return is_one(ramp_node->stride); + } +} + +For VectorizeLoop(const For &loop, int vectorize_hint) { + if (vectorize_hint <= 0) { + arith::Analyzer analyzer; + VectorizePlanner planner(&analyzer); + vectorize_hint = planner.Plan(loop); + } + if (vectorize_hint == 1) + return loop; + auto rewriter = VectorizeRewriter(vectorize_hint); + return Downcast(rewriter(loop)); +} + +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint) { + if (vectorize_hint <= 0) { + VectorizePlanner planner(analyzer); + vectorize_hint = planner.Plan(loop); + } + if (vectorize_hint == 1) + return loop; + auto rewriter = VectorizeRewriter(vectorize_hint); + return Downcast(rewriter(loop)); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/loop_vectorize.h b/tilelang/original/src/transform/loop_vectorize.h new file mode 100644 index 0000000000000000000000000000000000000000..92a756228dc1e3474328ba5d6188bc70e37b639e --- /dev/null +++ b/tilelang/original/src/transform/loop_vectorize.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file loop_vectorize.h + * \brief A tool to automatically vectorize a for loop + */ + +#ifndef TVM_TL_LOOP_VECTORIZE_H_ +#define TVM_TL_LOOP_VECTORIZE_H_ + +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +int GetVectorizeSize(const For &loop); + +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer); + +For VectorizeLoop(const For &loop, int vectorize_hint = -1); + +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint = -1); + +// Can prove expr is independent with var, i.e. 
the value of expr doesn't change +// when var changes +bool CanProveIndependent(const PrimExpr &expr, Var var, + arith::Analyzer *analyzer); + +// Check if expr is invariant within vector boundaries +bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var, + int target_vectorized_size, + arith::Analyzer *analyzer); + +bool IndiceCanVectorize(const PrimExpr &expr, Var var, + const PrimExpr &iter_var_size, + int target_vectorized_size, arith::Analyzer *analyzer); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_LOOP_VECTORIZE_H_ diff --git a/tilelang/original/src/transform/lower_device_kernel_launch.cc b/tilelang/original/src/transform/lower_device_kernel_launch.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2d8ae239ab812fbff10df05bc668dd5ce39127e --- /dev/null +++ b/tilelang/original/src/transform/lower_device_kernel_launch.cc @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_device_kernel_launch.cc + * \brief Split device function from host. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace ffi; +namespace { +struct KernelInfo { + // The device on which the PrimFunc runs + Target target; + + // The externally visible symbol which may refer to the PrimFunc + // when launching a device kernel. + String global_symbol; + + // The parameters accepted by the PrimFunc. Used to rewrite + // `launch_args` to be in terms of the calling scope. + Array params; + + // The launch parameters that should annotate the PrimFunc, if the + // kernel is ever called from the host. + Array launch_params; + + // Additional arguments which must be provided to the host-side + // PackedFunc. These may be in terms of the function's parameters + // (e.g. a function that computes the average of `N` elements, and + // which must be launched with `N` CUDA threads). + Array launch_args; + + // The extent of each thread + Map thread_extent; + // The amount of dynamic shared memory used + Optional dyn_shmem_size{std::nullopt}; +}; + +/*! + * \brief Visitor class to collect device-side program information. 
+ */ +class DeviceInfoCollector : public StmtVisitor { +public: + static KernelInfo Collect(const GlobalVar &gvar, const PrimFunc &func) { + DeviceInfoCollector collector; + collector.info_.target = + func->GetAttr(tvm::attr::kTarget).value().WithoutHost(); + collector.info_.params = func->params; + + collector(func->body); + + // The dynamic shared memory is required to be the last of the + // kernel launch parameters + if (collector.dyn_shmem_size) { + collector.info_.launch_params.push_back( + tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); + } + + collector.info_.global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol) + .value_or(gvar->name_hint); + + collector.info_.launch_args = collector.info_.launch_params.Map( + [&](const auto ¶m) { return collector.GetArgument(param); }); + collector.info_.dyn_shmem_size = collector.dyn_shmem_size; + collector.info_.thread_extent = collector.thread_extent; + return collector.info_; + } + +private: + PrimExpr GetArgument(const String &launch_param) const { + if (launch_param == + tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { + CHECK(dyn_shmem_size.defined()) + << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc did not contain Allocate node with shared " + "dynamic scope."; + return dyn_shmem_size.value(); + } + + auto extent = thread_extent.Get(launch_param); + CHECK(extent) << "Compute kernel requires launch parameter \"" + << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" + << tir::attr::thread_extent + << "\" defining this thread extent"; + return extent.value(); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); + info_.launch_params.push_back(iv->thread_tag); + thread_extent.Set(iv->thread_tag, op->value); + } + } + + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateNode *op) final { + auto storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn") { + ICHECK(!dyn_shmem_size.defined()) + << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + + PrimExpr dyn_size = Integer(1); + for (const auto &extent : op->extents) { + dyn_size *= extent; + } + dyn_size *= op->dtype.bytes() * op->dtype.lanes(); + + dyn_shmem_size = dyn_size; + } + StmtVisitor::VisitStmt_(op); + } + + // The collected results + KernelInfo info_; + // recording what thread axis have been visited. 
+ std::unordered_set defined_thread; + // The extent of each thread + Map thread_extent; + // The amount of dynamic shared memory used + Optional dyn_shmem_size{std::nullopt}; +}; + +class ReturnRemover : public StmtExprMutator { +public: + static Stmt Apply(const Stmt &stmt) { + ReturnRemover mutator; + return mutator(stmt); + } + +private: + using Parent = StmtExprMutator; + Stmt VisitStmt_(const EvaluateNode *op) override { + if (auto *call = op->value.as()) { + if (call->op.same_as(builtin::ret())) { + ICHECK_EQ(call->args.size(), 1); + auto as_int = call->args[0].as(); + ICHECK(as_int && as_int->value == 0) + << "Device kernel may only contain successful return, T.ret(0)"; + return Evaluate(0); + } + } + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) override { + if (op->op.same_as(builtin::ret())) { + LOG(FATAL) << "Call to builtin::ret() should only appear within an " + "Evaluate node"; + } + return Parent::VisitExpr_(op); + } +}; +} // namespace + +class DeviceKernelMutator : public StmtExprMutator { +public: + using Parent = StmtExprMutator; + + explicit DeviceKernelMutator( + std::unordered_map device_info_map) + : device_info_map_(std::move(device_info_map)) {} + + PrimFunc RewriteKernelLaunchSite(const GlobalVar &gvar, PrimFunc func) { + ICHECK(!current_target_.defined()); + auto it = device_info_map_.find(gvar.get()); + ICHECK(it != device_info_map_.end()); + current_target_ = it->second.target; + + auto body = VisitStmt(func->body); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + + current_target_ = std::nullopt; + return func; + } + + PrimFunc UpdateKernelAttributes(const GlobalVar &gvar, PrimFunc func) const { + bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); + bool is_call_extern = extern_function_call_.count(gvar.get()); + CHECK(!is_kernel_launch || !is_call_extern) + << "Function " << gvar << " has multiple callees, " + << "and would need to be lowered into a call_extern at some call " + "sites, " + << "and a device kernel launch at others. " + << "This case is not yet supported."; + + if (is_kernel_launch || is_call_extern) { + func = + WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true)); + } + + if (is_kernel_launch) { + const auto &info = device_info_map_.at(gvar.get()); + + // Kernel launches provide an int32 error code to the caller, + // but do not accept any return type from the callee. + { + auto write_ptr = func.CopyOnWrite(); + write_ptr->ret_type = VoidType(); + write_ptr->body = ReturnRemover::Apply(write_ptr->body); + } + + func = + WithAttrs(std::move(func), + {{tvm::attr::kCallingConv, + Integer(tvm::CallingConv::kDeviceKernelLaunch)}, + {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}}); + } + // @lei: workaround as we may require c host codegen, so we need to set the + // global symbol for cpu backend. 
+ func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); + + const auto &info = device_info_map_.at(gvar.get()); + const auto &thread_extent = info.thread_extent; + func = WithAttr(std::move(func), "thread_extent", thread_extent); + if (info.dyn_shmem_size.defined()) { + func = WithAttr(std::move(func), "dyn_shared_memory_buf", + info.dyn_shmem_size.value()); + } + return func; + } + +private: + PrimExpr VisitExpr_(const CallNode *op) override { + auto node = Downcast(Parent::VisitExpr_(op)); + + auto *gvar = op->op.as(); + if (!gvar) + return std::move(node); + + auto it = device_info_map_.find(gvar); + ICHECK(it != device_info_map_.end()) + << "CallNode attempted subroutine call to " << gvar->name_hint + << ", but " << gvar->name_hint << " did not appear within the IRModule"; + const KernelInfo &dev_info = it->second; + + auto caller_target = current_target_.value(); + auto callee_target = dev_info.target; + + bool same_target = caller_target->str() == callee_target->str(); + + if (same_target) { + // Calls within the same target may be handled at codegen time + // as internal subroutine calls. + return std::move(node); + } + + bool same_device_type = caller_target->GetTargetDeviceType() == + callee_target->GetTargetDeviceType(); + if (same_device_type) { + // Calls to another target using the same device (e.g. LLVM + // calling a custom TIRToRuntime target) do not require a kernel + // launch, but need to be replaced with call_extern. + extern_function_call_.insert(gvar); + Array args; + args.push_back(StringImm(gvar->name_hint)); + for (const auto &arg : node->args) { + args.push_back(arg); + } + return Call(node->dtype, builtin::call_extern(), args); + } + + ICHECK(dev_info.launch_params.defined()) + << "CallNode attempted kernel launch to " << gvar->name_hint + << " on target " << dev_info.target << ", but subroutine " + << gvar->name_hint + << " did not have the tir::attr::kKernelLaunchParams attribute " + << "required for cross-target kernel launch"; + + // Collected kernel information may be in terms of the callee's + // arguments, but we need expressions for them in terms of the + // caller's parameters. The param_map allows substitution of + // parameter values into the thread extents, to generate + // expressions that are valid within the caller. + Map param_map = [&]() { + Map param_map; + CHECK_EQ(node->args.size(), dev_info.params.size()) + << "Function " << gvar->name_hint << " accepts " + << dev_info.params.size() + << " arguments as input, but is called using " << node->args.size() + << " arguments"; + for (size_t i = 0; i < node->args.size(); i++) { + param_map.Set(dev_info.params[i], node->args[i]); + } + return param_map; + }(); + + device_kernel_launch_.insert(gvar); + + Array call_args; + call_args.push_back(StringImm(dev_info.global_symbol)); + for (PrimExpr arg : node->args) { + call_args.push_back(arg); + } + for (const auto &launch_arg : dev_info.launch_args) { + call_args.push_back(Substitute(launch_arg, param_map)); + } + + auto dtype = node->dtype.is_void() ? 
DataType::Int(32) : node->dtype; + + return Call(dtype, builtin::tvm_call_packed(), call_args); + } + + Optional current_target_; + std::unordered_map device_info_map_; + std::unordered_set device_kernel_launch_; + std::unordered_set extern_function_call_; +}; + +namespace transform { + +tvm::transform::Pass LowerDeviceKernelLaunch() { + auto pass_func = [](IRModule mod, + const tir::transform::PassContext &ctx) -> IRModule { + auto mutator = [&mod]() { + std::unordered_map device_info_map; + for (const auto &[gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + device_info_map[gvar.get()] = + DeviceInfoCollector::Collect(gvar, prim_func.value()); + } + } + return DeviceKernelMutator(std::move(device_info_map)); + }(); + + { + IRModule updates; + for (const auto &[gvar, base_func] : mod->functions) { + if (auto *ptr = base_func.as()) { + auto prim_func = mutator.RewriteKernelLaunchSite( + gvar, tvm::ffi::GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (!updates->functions.empty()) { + mod.CopyOnWrite()->Update(updates); + } + } + { + IRModule updates; + for (const auto &[gvar, base_func] : mod->functions) { + if (auto *ptr = base_func.as()) { + auto prim_func = mutator.UpdateKernelAttributes( + gvar, tvm::ffi::GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (!updates->functions.empty()) { + mod.CopyOnWrite()->Update(updates); + } + } + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, + "tl.LowerDeviceKernelLaunch", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch", + LowerDeviceKernelLaunch); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_device_storage_access_info.cc b/tilelang/original/src/transform/lower_device_storage_access_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dc46e98577615cb0456aecc940c4e1247a0d5ae --- /dev/null +++ b/tilelang/original/src/transform/lower_device_storage_access_info.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_device_storage_access.cc + * \brief Lower the special device storage access. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { +using namespace tir; + +using runtime::StorageRank; +using runtime::StorageScope; + +class StorageAccessInfoLower : public StmtExprMutator { +public: + Stmt VisitStmt_(const AllocateNode *op) final { + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" && + scope.tag != ".barrier" && scope.tag.find(".descriptor") != 0) { + auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); + ICHECK(info.defined()) + << "Cannot find memory info of " << scope.to_string(); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); + storage_info_[op->buffer_var.get()] = info; + + // Lower allocate to device allocate when needed. + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + if (info->head_address.defined()) { + return LetStmt(op->buffer_var, info->head_address, op->body); + } else { + return op->body; + } + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (auto it = storage_info_.find(node->buffer->data.get()); + it != storage_info_.end() && !it->second->head_address.defined()) { + return node->body; + } else { + return std::move(node); + } + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + return MakeAccessPtr(op); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + +private: + // tvm_access_ptr + PrimExpr MakeAccessPtr(const CallNode *op) { + // Specially handle the buffer packed intrinsic + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + const VarNode *buffer = op->args[1].as(); + Var buffer_var = Downcast(op->args[1]); + PrimExpr offset = op->args[2]; + auto it = storage_info_.find(buffer); + if (it != storage_info_.end() && it->second.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, + it->second); + } + ICHECK(op->dtype.is_handle()); + // Change to address_of + return AddressOffset(buffer_var, dtype, offset); + } + + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, const Var &buffer_var, + DataType dtype, const PrimExpr &offset, + const MemoryInfo &info) { + if (ptr_type.is_handle()) { + ICHECK(info->head_address.defined()) + << buffer_var << " is not adddressable."; + return AddressOffset(buffer_var, dtype, offset); + } + int dtype_bits = dtype.bits() * dtype.lanes(); + ICHECK_EQ(info->unit_bits % dtype_bits, 0); + return cast( + ptr_type, + analyzer_.Simplify( + offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); + } + // The storage scope of each buffer + std::unordered_map storage_info_; + // analyzer + arith::Analyzer analyzer_; +}; + +Stmt LowerStorageAccessInfo(Stmt stmt) { + return StorageAccessInfoLower()(std::move(stmt)); +} + +namespace transform { +using namespace tir::transform; + +Pass LowerDeviceStorageAccessInfo() { + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + n->body = StorageAccessInfoLower()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerDeviceStorageAccessInfo", 
+ {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo", + LowerDeviceStorageAccessInfo); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_hopper_intrin.cc b/tilelang/original/src/transform/lower_hopper_intrin.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9c848ac91a54294ad64e482801583b76d792b50 --- /dev/null +++ b/tilelang/original/src/transform/lower_hopper_intrin.cc @@ -0,0 +1,229 @@ +/*! + * \file lower hopper intrin.cc + * \brief Lower Hopper intrinsics cuda GPU(sm90+) + */ + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../runtime/runtime.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +#if (CUDA_MAJOR_VERSION >= 12) +class LowerHopperIntrin : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) { + PrimFuncNode *fptr = f.CopyOnWrite(); + LowerHopperIntrin substituter(disable_shuffle_elect); + fptr->body = substituter.VisitStmt(f->body); + Map> init_desc_arg_map; + // Collect prologue/epilogue statements for host-side setup/teardown + Array prologue_stmts; + Array epilogue_stmts; + for (const auto &[call, var] : substituter.desc_map_) { + // Should allocate 128 bytes for TensorMap on stack + Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(), + {StringImm("tvm_ffi_any"), 16}); + Array init_desc_args; + if (call->op.same_as(create_tma_descriptor())) { + init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled)); + } else if (call->op.same_as(create_tma_im2col_descriptor())) { + init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col)); + } else { + CHECK(0) << call->op; + } + init_desc_args.push_back(var); + init_desc_args.insert(init_desc_args.end(), call->args.begin(), + call->args.end()); + // add to function attribute + Call init_desc = + Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); + // Accumulate TMA descriptor init into prologue + prologue_stmts.push_back(LetStmt(var, alloc_desc, Evaluate(init_desc))); + init_desc_arg_map.Set(var, init_desc_args); + } + f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map); + + // Additionally, if L2 persistent cache annotations were lowered earlier, + // materialize TVM FFI calls to set the stream access policy window. 
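+ // The "l2_persistent_map" attribute is expected to map a buffer name to [hit_ratio, size_in_bytes], e.g. {"A": [0.8, 1048576]} to keep roughly 1 MiB of A resident in L2 with a 0.8 hit ratio; the concrete values here are purely illustrative.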
+ if (f->attrs.defined() && f->attrs->dict.count("l2_persistent_map")) { + auto l2_map = + f->GetAttr>>("l2_persistent_map"); + if (l2_map.defined()) { + // Build a lookup from buffer name to Buffer object + std::unordered_map name2buf; + for (const auto &kv : f->buffer_map) { + name2buf.emplace(kv.second->name, kv.second); + } + for (const auto &kv : l2_map.value()) { + const std::string buf_name = kv.first; + const Array &args = kv.second; + if (name2buf.count(buf_name) == 0) { + continue; + } + const Buffer &buf = name2buf.at(buf_name); + // Build base pointer expression (read access) + PrimExpr base_ptr = buf.access_ptr(1); + // Args packed: func_name, base_ptr, num_bytes, hit_ratio + Array packed_args; + packed_args.push_back( + StringImm(tvm_cuda_stream_set_access_policy_window)); + packed_args.push_back(base_ptr); + // size_in_bytes (args[1]) then hit_ratio (args[0]) + ICHECK_GE(args.size(), 2); + packed_args.push_back(args[1]); + packed_args.push_back(args[0]); + prologue_stmts.push_back(Evaluate(Call( + DataType::Int(32), builtin::tvm_call_packed(), packed_args))); + } + // Add a single epilogue call to reset the access policy window and + // restore L2 limit + Array reset_args; + reset_args.push_back( + StringImm(tvm_cuda_stream_reset_access_policy_window)); + epilogue_stmts.push_back(Evaluate( + Call(DataType::Int(32), builtin::tvm_call_packed(), reset_args))); + } + } + + // Stitch prologue statements before the original body + if (!prologue_stmts.empty()) { + // Chain the Let/Evaluate statements sequentially + Stmt seq = prologue_stmts.size() == 1 ? prologue_stmts[0] + : SeqStmt(prologue_stmts); + fptr->body = SeqStmt({seq, fptr->body}); + } + if (!epilogue_stmts.empty()) { + Stmt seq_end = epilogue_stmts.size() == 1 ? epilogue_stmts[0] + : SeqStmt(epilogue_stmts); + fptr->body = SeqStmt({fptr->body, seq_end}); + } + return f; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + // Insert the prefetch TMA descriptor statement TO the beginning of the + // kernel + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + auto body = StmtExprMutator::VisitStmt(op->body); + if (prefetch_calls_.empty() && init_mbarrier_calls_.empty()) { + return AttrStmt(op->node, op->attr_key, op->value, body); + } else { + Array stmt_seq; + if (!init_mbarrier_calls_.empty()) { + auto alloc_mbarrier = + Evaluate(Call(DataType::Handle(), builtin::create_barriers(), + {static_cast(init_mbarrier_calls_.size())})); + stmt_seq.push_back(alloc_mbarrier); + } + + auto stmts = prefetch_calls_; + stmts.insert(stmts.end(), init_mbarrier_calls_.begin(), + init_mbarrier_calls_.end()); + PrimExpr condition; + if (!disable_shuffle_elect_) { + condition = Call(DataType::Bool(), tl_shuffle_elect(), {0}); + } else { + condition = EQ(iv->var, 0); + } + auto stmt_ = IfThenElse(condition, + stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]); + stmt_seq.push_back(stmt_); + if (!init_mbarrier_calls_.empty()) { + // Note from FlashAttention: + // Helps with visibility of barrier init operations across warps / + // cta / cluster Available as a separate function so as to batch + // inits across barriers and fence once Note : It must be composed + // with an appropriate sync instruction with the right scope to + // ensure visibility eg. 
__syncthreads() or a cluster_arrive() + + // cluster_wait() + Stmt mem_fence = Evaluate(Call( + DataType::Handle(), tvm::tl::ptx_fence_barrier_init(), {})); + stmt_seq.push_back(mem_fence); + Stmt mem_sync = + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")})); + stmt_seq.push_back(mem_sync); + } + stmt_seq.push_back(body); + + prefetch_calls_.clear(); + init_mbarrier_calls_.clear(); + return AttrStmt(op->node, op->attr_key, op->value, SeqStmt(stmt_seq)); + } + } + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *call) final { + if (call->op.same_as(create_tma_descriptor()) || + call->op.same_as(create_tma_im2col_descriptor())) { + Var var; + auto iter = desc_map_.find(tvm::ffi::GetRef(call)); + if (iter != desc_map_.end()) { + var = iter->second; + } else { + String name = call->args[2].as().value()->name_hint; + var = Var(name + "_desc", + PointerType(PrimType(cuTensorMapType()), "grid_constant")); + desc_map_[tvm::ffi::GetRef(call)] = var; + prefetch_calls_.push_back( + Evaluate(Call(DataType::Handle(), builtin::call_extern(), + {StringImm("tl::prefetch_tma_descriptor"), var}))); + } + return var; + } else if (call->op.same_as(create_list_of_mbarrier())) { + ICHECK(init_mbarrier_calls_.empty()); + int num_barriers = static_cast(call->args.size()); + for (int i = 0; i < num_barriers; i++) { + PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i}); + init_mbarrier_calls_.push_back(Evaluate( + Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(), + {mbarrier, call->args[i]}))); + } + return 0; + } else { + return StmtExprMutator::VisitExpr_(call); + } + } + +private: + Array prefetch_calls_; + Array init_mbarrier_calls_; + std::unordered_map desc_map_; + LowerHopperIntrin(bool disable_shuffle_elect) + : disable_shuffle_elect_(disable_shuffle_elect) {} + bool disable_shuffle_elect_; +}; + +using namespace tir::transform; + +tvm::transform::Pass LowerHopperIntrin() { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_shuffle_elect = + ctx->GetConfig(kDisableShuffleElect, Bool(false)).value(); + return LowerHopperIntrin::Substitute(f, disable_shuffle_elect); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin); +} +#endif // (CUDA_MAJOR_VERSION >= 12) + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_intrin.cc b/tilelang/original/src/transform/lower_intrin.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf312264d7dc7a25cb515e5c7f5c91a0d475b797 --- /dev/null +++ b/tilelang/original/src/transform/lower_intrin.cc @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Lower intrinsic calls and ops to device specific ir when possible. + * \file lower_intrin.cc + */ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/pattern_match.h" + +namespace tvm { +namespace tl { +using namespace tir; +using namespace ffi; + +class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { +public: + using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; + using FLowerGeneral = ffi::TypedFunction; + + IntrinInjecter(arith::Analyzer *analyzer, std::string target, + std::string mtriple = "") + : IRMutatorWithAnalyzer(analyzer) { + std::vector patterns; + patterns.push_back(target + ".FLowerIntrinsic"); + patterns.push_back(target + ".FLegalize"); + bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); + if (is_llvm_aarch64) { + patterns.push_back(target + ".aarch64.FLowerIntrinsic"); + patterns.push_back(target + ".aarch64.FLegalize"); + } + patterns.push_back("default.FLowerIntrinsic"); + patterns.push_back("default.FLegalize"); + + for (const std::string &pattern : patterns) + if (Op::HasAttrMap(pattern)) { + attr_maps_.push_back(Op::GetAttrMap(pattern)); + if (fma_ == nullptr) { + fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr); + } + } + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (auto *ptr_op = op->op.as()) { + for (const auto &f_attr_map : attr_maps_) { + FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef(ptr_op), nullptr); + if (f != nullptr) { + PrimExpr e = tvm::ffi::GetRef(op); + PrimExpr r = f(e); + ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; + if (!r.same_as(e)) { + r = this->VisitExpr(r); + if (r.defined()) { + return r; + } + } + } + } + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const AddNode *op) final { + if (const MulNode *mb = op->b.as()) { + return MakeFMA(mb->a, mb->b, op->a, op); + } else if (const MulNode *ma = op->a.as()) { + return MakeFMA(ma->a, ma->b, op->b, op); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + // We use floordiv for integer analysis, + // but will need to lower them to native truncdiv instructions + PrimExpr VisitExpr_(const FloorDivNode *op) final { + auto e = tvm::ffi::GetRef(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + op = ret.as(); + if (op == nullptr) + return ret; + int shift; + const DataType &dtype = op->dtype; + ICHECK(dtype.is_int() || dtype.is_uint()); + + // lower (a + 31) // 512 to (a + 31) >> 5 + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { + // lower to right shift if possible. + return op->a >> make_const(dtype, shift); + } + + if (analyzer_->CanProveGreaterEqual(op->b, 0)) { + // Common path, positive divisor + if (analyzer_->CanProveGreaterEqual(op->a, 0) || + analyzer_->CanProveGreaterEqual(e, 0)) { + return truncdiv(op->a, op->b); + } + + // NOTE: Disabled due to integer overflow risk in `a + b * c`. 
+ // The transformation `floordiv(a,b) -> truncdiv(a + b*c, b) - c` + // may overflow when `a` is near type limit and `c` is large, + // producing incorrect results. + + // If the numerator's lower bound is known, express the floordiv + // in terms of truncdiv using only positive operands. + /* + arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); + if (const_int_bound->min_value < 0 && + const_int_bound->min_value > + -(Downcast(tvm::max_value(op->a->dtype.element_of())) + ->value)) { + // The goal is to write floordiv(a,b) in terms of truncdiv, without + // using negative operands. + // + // For any integer c + // + // floordiv(a,b) == floordiv(a + b*c - b*c, b) + // == floordiv(a + b*c, b) - c + // + // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of + // truncdiv as follows. + // + // c == ceildiv(-a_min,b) + // == floordiv(-a_min + (b-1), b) + // == truncdiv(-a_min + (b-1), b) + // + // When substituted into `a + b*c`, this results in a positive argument. + // + // a + b*c + // == a + b*ceildiv(-a_min,b) + // == a - b*floordiv(a_min,b) + // >= a - b*floordiv(a,b) + // == floormod(a, b) + // >= 0 + // + // Since the argument is positive, this allows floordiv to be written as + // followed. + // + // floordiv(a,b) + // == floordiv(a + b*c, b) - c + // == truncdiv(a + b*c, b) - c + IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); + PrimExpr offset_numerator = + analyzer_->Simplify(op->a + op->b * ceildiv); + return truncdiv(offset_numerator, op->b) - ceildiv; + } + */ + + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; + PrimExpr rdiv = truncdiv(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); + // condition on b >= 0. + // truncmod(a, b) < 0 will implies ceildiv, + // So we need to correct these cases. + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && + support_bitwise_op_) { + // equivalent to rdiv + (rmod >= 0 ? 0: -1); + return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); + } else { + return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); + } + + } else { + if (dtype.is_float()) { + // floor(a / b) + return VisitExpr_(tvm::floor(op->a / op->b).as()); + } else { + // uncommon case + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; + auto rmod = tir::Var("rmod", dtype); + auto rdiv = tir::Var("rdiv", dtype); + // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) + // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) + PrimExpr let_rdiv = tir::Let( + rdiv, truncdiv(op->a, op->b), + tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), + rdiv, rdiv - make_const(dtype, 1))); + return Let(rmod, truncmod(op->a, op->b), let_rdiv); + } + } + } + + PrimExpr VisitExpr_(const FloorModNode *op) final { + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + op = ret.as(); + if (op == nullptr) + return ret; + // Lower floordiv to native truncdiv. + int shift; + const DataType &dtype = op->dtype; + ICHECK(dtype.is_int() || dtype.is_uint()); + + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { + // lower to masking if possible. + int64_t mask = + (static_cast(1) << static_cast(shift)) - 1; + return op->a & make_const(dtype, mask); + } + + if (analyzer_->CanProveGreaterEqual(op->b, 0)) { + // Common pass, positive divisor + if (analyzer_->CanProveGreaterEqual(op->a, 0)) { + return truncmod(op->a, op->b); + } + + // NOTE: Disabled due to integer overflow risk in `a + b * c`. 
+ // The transformation `floordiv(a,b) -> truncdiv(a + b*c, b) - c` + // may overflow when `a` is near type limit and `c` is large, + // producing incorrect results. + + // If the numerator's lower bound is known, express the floormod + // in terms of truncmod using only positive operands. + /* + arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); + if (const_int_bound->min_value < 0 && + const_int_bound->min_value > + -(Downcast(tvm::max_value(op->a->dtype.element_of())) + ->value)) { + // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, + // without using negative operands. + // + // For any integer c + // + // floormod(a, b) == floormod(a + b*c, b) + // + // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of + // truncdiv as follows. + // + // c == ceildiv(-a_min,b) + // == floordiv(-a_min + (b-1), b) + // == truncdiv(-a_min + (b-1), b) + // + // When substituted into `a + b*c`, this results in a positive argument. + // + // a + b*c + // == a + b*ceildiv(-a_min,b) + // == a - b*floordiv(a_min,b) + // >= a - b*floordiv(a,b) + // == floormod(a, b) + // >= 0 + // + // Since the argument is positive, this allows floordiv to be written as + // followed. + // + // floormod(a,b) + // == floormod(a + b*c, b) + // == truncmod(a + b*c, b) + IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); + PrimExpr offset_numerator = + analyzer_->Simplify(op->a + op->b * ceildiv); + return truncmod(offset_numerator, op->b); + } + */ + + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; + // NOTE:condition on b >= 0. + // mod(a, b) < 0 will imply we are doing ceildiv, + // So we need to correct these cases. + PrimExpr rmod = truncmod(op->a, op->b); + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && + support_bitwise_op_) { + // (rmod >> shift) & b + // -> (rmod >= 0 ? 0: -1) & b + // -> rmod >= 0 ? 
0 : b + return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); + } else { + return tir::Select(rmod >= 0, rmod, rmod + op->b); + } + + } else { + if (dtype.is_float()) { + // a - floor(a / b) * b + return op->a - + (VisitExpr_(tvm::floor(op->a / op->b).as()) * op->b); + } else { + // uncommon case + DLOG(INFO) + << "LowerFloorMod: Cannot decide the sign of divsor and divident"; + auto rmod = tir::Var("rmod", dtype); + // b > 0 && rmod >= 0 -> rmod + // b > 0 && rmod < 0 -> rmod + b + // b < 0 && rmod < 0 -> rmod + // b < 0 && rmod > 0 -> rmod + b + return Let(rmod, truncmod(op->a, op->b), + Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), + rmod, rmod + op->b)); + } + } + } + + PrimExpr VisitExpr_(const MaxNode *op) final { + using namespace arith; + PVar x, y; + PVar c; + auto e = tvm::ffi::GetRef(op); + if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && + analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { + return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const EQNode *op) final { + using namespace arith; + PVar x, y; + auto e = tvm::ffi::GetRef(op); + if ((floormod(x, y) == 0).Match(e)) { + return VisitExpr((truncmod(x, y) == 0).Eval()); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const NENode *op) final { + using namespace arith; + PVar x, y; + auto e = tvm::ffi::GetRef(op); + if ((floormod(x, y) != 0).Match(e)) { + return VisitExpr((truncmod(x, y) != 0).Eval()); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + +private: + PrimExpr SwapBroadcastCast(const PrimExpr &e) { + // Try to change broadcast(cast(x)) to cast(broadcast(x)) + // For some targets, LLVM will generate more efficient FMA + // instruction with the latter. For example, vmla vs. vmlal + // on ARM. + if (const BroadcastNode *bcast = e.as()) { + if (const CastNode *cast = bcast->value.as()) { + auto should_swap = [&]() { + // Maintain behaviour (int8 -> int16, fp16 -> fp32). + if (cast->dtype.bits() == cast->value.dtype().bits() * 2) { + return true; + } + // Check both operands are integer-like. + if (!cast->dtype.is_uint() && !cast->dtype.is_int()) { + return false; + } + if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) { + return false; + } + // If both are integer-like, swap if we have a widening cast. 
+ return cast->dtype.bits() > cast->value.dtype().bits(); + }; + + if (should_swap()) { + PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes); + return Cast(bcast->dtype, new_bcast); + } + } + } + return e; + } + + PrimExpr MakeFMA(const PrimExpr &a, const PrimExpr &b, const PrimExpr &c, + const AddNode *op) { + // emit fma instruction: a * b + c + PrimExpr lhs = SwapBroadcastCast(a); + PrimExpr rhs = SwapBroadcastCast(b); + + if (fma_ != nullptr && op->dtype.is_float()) { + PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); + if (r.defined()) + return this->VisitExpr(r); + } else { + if (!lhs.same_as(a) || !rhs.same_as(b)) { + PrimExpr mul = this->VisitExpr(Mul(lhs, rhs)); + return Add(mul, this->VisitExpr(c)); + } + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + // attribute maps, shared only when FLegalize == FLowerIntrinsic + std::vector> attr_maps_; + FLowerGeneral fma_{nullptr}; + bool support_bitwise_op_{true}; +}; + +Stmt LowerIntrinStmt(Stmt stmt, const std::string &target) { + arith::Analyzer analyzer; + return IntrinInjecter(&analyzer, target)(std::move(stmt)); +} + +namespace transform { + +tir::transform::Pass LowerIntrin() { + using namespace tir::transform; + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto *n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; + arith::Analyzer analyzer; + auto mtriple = target.value()->GetAttr("mtriple", ""); + n->body = IntrinInjecter(&analyzer, target.value()->kind->name, + mtriple.value())(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin); +} + +} // namespace transform + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_l2_persistent_annotation.cc b/tilelang/original/src/transform/lower_l2_persistent_annotation.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f7be710de3ee160692d98b058417b11149f2edc --- /dev/null +++ b/tilelang/original/src/transform/lower_l2_persistent_annotation.cc @@ -0,0 +1,107 @@ +/*! 
+ * \file lower_l2_persistent_annotation.cc + * \brief Lower L2 persistent annotation + */ + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../runtime/runtime.h" + +namespace tvm { +namespace tl { + +namespace attr { +// BlockAttr, Containing the layout for all the buffers in the block +constexpr const char *kL2RatioMap = "l2_hit_ratio_map"; +constexpr const char *kL2PersistentMap = "l2_persistent_map"; +} // namespace attr + +using namespace tir; + +class LowerL2Persistent : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f) { + PrimFuncNode *fptr = f.CopyOnWrite(); + LowerL2Persistent substituter; + // Trace the buffer map for tvm_access_ptr + substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end()); + for (const auto &[_, buffer] : f->buffer_map) { + substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + fptr->body = substituter.VisitStmt(f->body); + Map> init_l2_persistent_map; + for (auto [buffer, hit_ratio] : substituter.hit_ratio_map_) { + Array l2_persistent_arguments; + // Argument 0: hit ratio + // Argument 1: size in bytes + l2_persistent_arguments.push_back(hit_ratio); + PrimExpr size_in_bytes = IntImm(DataType::Int(64), buffer->dtype.bytes()); + for (auto dim : buffer->shape) { + size_in_bytes = size_in_bytes * dim; + } + l2_persistent_arguments.push_back(size_in_bytes); + init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments); + } + if (!init_l2_persistent_map.empty()) { + f = WithAttr(std::move(f), attr::kL2PersistentMap, + init_l2_persistent_map); + } + return f; + } + + Stmt VisitStmt_(const BlockNode *op) final { + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : op->alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + if (op->annotations.count(attr::kL2RatioMap)) { + auto hit_ratio_map = op->annotations.at(attr::kL2RatioMap) + .as>() + .value(); + for (auto [buffer_var, hit_ratio] : hit_ratio_map) { + Buffer buffer = buffer_data_to_buffer_.at(buffer_var); + hit_ratio_map_.Set(buffer, hit_ratio); + } + } + auto block = Downcast(StmtExprMutator::VisitStmt_(op)); + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.erase(attr::kL2RatioMap); + return block; + } + +private: + // Mapping from data Var of a Buffer to Buffer, for lookup + Map buffer_data_to_buffer_; + std::unordered_map buffer_map_; + Map hit_ratio_map_; + LowerL2Persistent() = default; +}; + +using namespace tir::transform; + +tvm::transform::Pass LowerL2Persistent() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return LowerL2Persistent::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_opaque_block.cc b/tilelang/original/src/transform/lower_opaque_block.cc new file mode 100644 index 0000000000000000000000000000000000000000..76dc36a6a3d774b725338e7befd043401c57b133 --- /dev/null +++ b/tilelang/original/src/transform/lower_opaque_block.cc @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_opaque_block.cc + */ + +#include +#include +#include +#include + +#include +#include + +#include "../op/builtin.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace tir::attr; +/*! + * \brief Remove Block to ensure that the TIR can not be scheduled again. + */ +class OpaqueBlockLower : public StmtExprMutator { +public: + static PrimFunc Rewrite(PrimFunc f) { + auto fptr = f.CopyOnWrite(); + OpaqueBlockLower lower; + if (auto existing = + fptr->attrs.GetAttr>(tl::attr::kLocalVarInit)) { + lower.local_var_init_map_ = existing.value(); + } + lower.storage_align_ = CollectStorageAlignAnnotation(fptr->body); + fptr->body = lower(std::move(fptr->body)); + if (!lower.local_var_init_map_.empty()) { + f = WithAttr(std::move(f), tl::attr::kLocalVarInit, + lower.local_var_init_map_); + } + return f; + } + +private: + Stmt VisitStmt_(const BlockRealizeNode *op) final { + // We have convert blocks into opaque blocks in previous passes. + ICHECK(op->iter_values.empty()) + << "Non-opaque blocks are not allowed in FlattenBuffer. Please " + "call pass ConvertBlocksToOpaque before."; + // Step 1. Visit the body + Block new_block = Downcast(this->VisitStmt(op->block)); + PrimExpr predicate = this->VisitExpr(op->predicate); + // Step 2. Transform the `predicate` to if-then-else + Stmt body = new_block->body; + if (!is_one(predicate)) { + body = IfThenElse(predicate, std::move(body)); + } + // Step 3. Handle annotations, block annotations are not preserved by + // default. + std::vector> pragma_attrs; + HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true, + new_block->alloc_buffers); + + // Step 4. Handle allocations in reverse order + for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { + const Buffer &buffer = new_block->alloc_buffers[i - 1]; + Array allocation_shape = GetBufferAllocationShape(buffer); + body = DeclBuffer(buffer, std::move(body)); + Map allocate_annotations; + auto it = storage_align_.find(buffer->data); + if (it != storage_align_.end()) { + StorageAlignAnnotation allocate_aligns; + for (auto tuple : it->second) { + tuple.Set<0>(-1); + allocate_aligns.push_back(tuple); + } + allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns); + } + auto init_it = local_var_init_map_.find(buffer->data); + if (init_it != local_var_init_map_.end()) { + const PrimExpr &init = (*init_it).second; + allocate_annotations.Set(tl::attr::kLocalVarInit, init); + } + body = Allocate(buffer->data, buffer->dtype, allocation_shape, + const_true(), std::move(body), allocate_annotations); + } + // Step 5. 
Insert attribute statements converted from pragmas + for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { + body = AttrStmt(Integer(0), it->first, it->second, std::move(body)); + } + return body; + } + Stmt VisitStmt_(const BlockNode *op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + if (block->annotations.count("stmt_group")) { + return block->body; + } + return block; + } + + Stmt VisitStmt_(const ForNode *op) final { + // Step 1. Update unit loop info. + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + if (is_one(extent) && IsEffectivelyEmptyAnnotation(op->annotations)) { + // handling unit loop + unit_loop_vars_[op->loop_var] = min; + } + // Step 2. Visit recursively + Stmt body = this->VisitStmt(op->body); + // Step 3. Handle annotations + std::vector> pragma_attrs; + Map new_annotations = + HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false); + // Step 4. Create new For loop accordingly + if (op->kind == ForKind::kThreadBinding) { + // Case 1. Thread binding + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); + } else if (is_one(extent) && + IsEffectivelyEmptyAnnotation(op->annotations)) { + // Case 2. Unit loop + return body; + } else { + // Case 3. An ordinary loop + body = For(op->loop_var, std::move(min), std::move(extent), op->kind, + std::move(body), std::nullopt, new_annotations); + } + // Step 5. Insert nested attrs + for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { + body = AttrStmt(op->loop_var, it->first, it->second, std::move(body)); + } + return body; + } + + // Treat annotations as empty if they are truly empty or contain only + // the unroll hint `pragma_unroll_explicit`. This allows unit-length + // loops produced by unroll pragmas to be simplified away. + bool + IsEffectivelyEmptyAnnotation(const Map &annotations) const { + if (annotations.empty()) { + return true; + } + if (annotations.size() == 1) { + auto it = annotations.find(tir::attr::pragma_unroll_explicit); + if (it != annotations.end()) { + return true; + } + } + return false; + } + + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = tvm::ffi::GetRef(op); + auto it = unit_loop_vars_.find(var); + if (it == unit_loop_vars_.end()) { + return var; + + } else { + PrimExpr expr = it->second; + if (expr.dtype() != var.dtype()) { + expr = tvm::cast(var.dtype(), std::move(expr)); + } + return expr; + } + } + + static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, + const String &thread_tag, Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(std::move(min), extent), + /*var=*/std::move(var), + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? tir::attr::virtual_thread + : tir::attr::thread_extent; + return AttrStmt(/*node=*/std::move(iter_var), + /*attr_key=*/std::move(attr_key), + /*value=*/std::move(extent), + /*body=*/std::move(body)); + } + + /*! \brief Convert attr value from annotation map into PrimExpr. 
*/ + PrimExpr ConvertAttrValue(const String &key, const Any &obj) { + if (obj == nullptr) { + return PrimExpr(); + } else if (auto expr = obj.try_cast()) { + return expr.value(); + } else if (auto str = obj.try_cast()) { + return std::move(StringImm(str.value())); + } else { + LOG(FATAL) << "Illegal attribute of key " << key << ", value type " + << obj.GetTypeKey() << " not supported"; + return PrimExpr(); + } + } + + /*! + * \brief Helper to handle annotation dict. + * (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They + * are lowered to `AttrStmt` by legacy TE schedule convention. + * (2) the non-pragma loop annotations are preserved + * (3) the non-pragma block annotations are dropped + * \return New annotation dict with preserved keys. Also update pragma attr + * pairs ordered by key. + */ + Map + HandleAnnotations(const Map &annotations, + std::vector> *pragma_attrs, + bool is_block, + const Array &alloc_buffers = Array()) { + Map preserved_annotations; + pragma_attrs->clear(); + for (const auto &kv : annotations) { + const String &key = kv.first; + if (tir::attr::IsPragmaKey(key)) { + pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); + } else if (key == tl::attr::kLocalVarInit) { + if (auto local_init_map = kv.second.try_cast>()) { + for (const auto &pair : local_init_map.value()) { + local_var_init_map_.Set(pair.first, pair.second); + } + } else if (auto init_expr = kv.second.try_cast()) { + ICHECK(is_block) << "`" << tl::attr::kLocalVarInit + << "` on non-block annotations is not supported"; + Buffer target = ResolveLocalVarBuffer(alloc_buffers); + if (!target.defined()) { + LOG(WARNING) << "Failed to resolve buffer for `" + << tl::attr::kLocalVarInit << "` annotation"; + continue; + } + local_var_init_map_.Set(target->data, init_expr.value()); + } else { + LOG(FATAL) << "Expected `" << tl::attr::kLocalVarInit + << "` to be a PrimExpr or Map, but got " + << kv.second.GetTypeKey(); + } + } else if (!is_block) { + // the loop annotation is preserved + preserved_annotations.Set(key, kv.second); + } + } + std::sort( + pragma_attrs->begin(), pragma_attrs->end(), + [](const auto &p1, const auto &p2) { return p1.first < p2.first; }); + return preserved_annotations; + } + + Buffer ResolveLocalVarBuffer(const Array &alloc_buffers) const { + for (const Buffer &buffer : alloc_buffers) { + std::string scope = buffer.scope(); + if (scope.find("local.var") != std::string::npos) { + return buffer; + } + } + if (!alloc_buffers.empty()) { + return alloc_buffers.back(); + } + return Buffer(); + } + + /*! \brief Record the loop_var and loop start value of unit loops, whose + * extent is one. */ + std::unordered_map unit_loop_vars_; + + /*! \brief Attr keys to preserve into loop annotations. */ + std::unordered_set preserved_annotations_; + + /*! \brief The map from buffer var to its storage alignment information. */ + std::unordered_map storage_align_; + + /*! \brief Local var initializers collected from block annotations. 
*/ + Map local_var_init_map_; +}; + +PrimFunc TLLowerOpaqueBlock(PrimFunc f) { + return OpaqueBlockLower::Rewrite(std::move(f)); +} + +tir::transform::Pass LowerOpaqueBlock() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return TLLowerOpaqueBlock(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_shared_barrier.cc b/tilelang/original/src/transform/lower_shared_barrier.cc new file mode 100644 index 0000000000000000000000000000000000000000..991676cb8d7a26297b795f33f364fa9dc934e53d --- /dev/null +++ b/tilelang/original/src/transform/lower_shared_barrier.cc @@ -0,0 +1,214 @@ +/*! + * \file lower_shared_barrier.cc + * \brief Convert shared.barrier buffers to plain shared + ptx init. + */ +#include "../op/builtin.h" +#include "tvm/ir/type.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/stmt.h" +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class SharedBarrierRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt body, bool disable_shuffle_elect = false) { + SharedBarrierRewriter rewriter(disable_shuffle_elect); + return rewriter(std::move(body)); + } + +private: + SharedBarrierRewriter(bool disable_shuffle_elect) + : disable_shuffle_elect_(disable_shuffle_elect) {} + + Stmt VisitStmt_(const BlockNode *op) final { + Block block = tvm::ffi::GetRef(op); + Array alloc_buffers = op->alloc_buffers; + + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + + Array barrier_buffers; + + for (const auto &[data, buffer] : buffer_map_) { + const auto *ptr_type = + buffer->data->type_annotation.as(); + auto storage_scope = ptr_type->storage_scope; + ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType"; + if (storage_scope == "shared.barrier") { + barrier_buffers.push_back(buffer); + } + } + + if (barrier_buffers.empty()) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(thread_var_.defined()) << "thread_var_ is not defined"; + + for (auto buffer : barrier_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + /* + Transform the barrier buffers to new allocations + transform: + data_is_ready = T.alloc_buffer((128,), "uint64", scope="shared.barrier") + compute_is_done = T.alloc_buffer((128,), "uint64", + scope="shared.barrier") + + into: + data_is_ready = T.alloc_buffer((1,), "uint64", scope="shared") + compute_is_done = T.alloc_buffer((1,), "uint64", scope="shared") + + if tx == 0: + T.ptx_init_barrier_thread_count(data_is_ready[0], 128) + T.ptx_init_barrier_thread_count(compute_is_done[0], 128) + */ + + // 2. 
create new buffers + Array new_buffers; + for (auto buffer : barrier_buffers) { + auto data = buffer->data; + auto new_buffer = Buffer(data, buffer->dtype, Array({1}), + Array({1}), PrimExpr(0), buffer->name, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); + new_buffers.push_back(new_buffer); + buffer_remap_.Set(buffer, new_buffer); + } + + // remove the barrier buffers + alloc_buffers.MutateByApply([this](Buffer buf) { + if (buffer_remap_.find(buf) != buffer_remap_.end()) { + return buffer_remap_.at(buf); + } + return buf; + }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } else { + return StmtExprMutator::VisitStmt_(op); + } + + // 3. create init calls for new buffers + Array init_mbarrier_calls_; + for (auto buffer : barrier_buffers) { + auto data = buffer->data; + auto old_buffer = buffer_data_to_buffer_.at(data); + auto new_buffer = buffer_remap_.at(old_buffer); + auto count = old_buffer->shape[0]; + + auto call = + Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(), + {BufferLoad(new_buffer, {0}), PrimExpr(count)}); + init_mbarrier_calls_.push_back(Evaluate(call)); + } + if (init_mbarrier_calls_.empty()) + return block; + + Array new_body; + PrimExpr condition; + if (!disable_shuffle_elect_) { + condition = Call(DataType::Bool(), tl_shuffle_elect(), {0}); + } else { + condition = EQ(thread_var_->var, 0); + } + new_body.push_back(IfThenElse(condition, + init_mbarrier_calls_.size() == 1 + ? init_mbarrier_calls_.back() + : SeqStmt(init_mbarrier_calls_), + Stmt())); + new_body.push_back( + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")}))); + new_body.push_back(block->body); + + block.CopyOnWrite()->body = SeqStmt(new_body); + + return StmtExprMutator::VisitStmt_(block.get()); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto buffer = load->buffer; + if (buffer_remap_.count(buffer)) { + auto new_buffer = buffer_remap_[load->buffer]; + return BufferLoad(new_buffer, load->indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto buffer = store->buffer; + if (buffer_remap_.count(buffer)) { + auto new_buffer = buffer_remap_[store->buffer]; + return BufferStore(new_buffer, store->value, store->indices); + } + return store; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return StmtExprMutator::VisitStmt_(op); + } + + // This is a workaround for cpu backend, + // we need to define a thread_var for the serial loop. 
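+  // thread_var_ is captured from the threadIdx.x `thread_extent` AttrStmt
+  // (see the AttrStmtNode visitor above); when shuffle-elect is disabled it
+  // guards the barrier initialization with `thread_var_->var == 0`.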
+ IterVar thread_var_; + Map buffer_data_to_buffer_; + Map buffer_remap_; + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + // Disable shuffle elect for the warp specialized kernel + bool disable_shuffle_elect_; +}; + +PrimFunc LowerSharedBarrier(PrimFunc f, bool disable_shuffle_elect) { + f.CopyOnWrite()->body = + SharedBarrierRewriter::Rewrite(f->body, disable_shuffle_elect); + return f; +} + +namespace transform { +using namespace tir::transform; + +tvm::transform::Pass LowerSharedBarrier() { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_shuffle_elect = + ctx->GetConfig(kDisableShuffleElect, Bool(false)).value(); + return tl::LowerSharedBarrier(std::move(f), disable_shuffle_elect); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_shared_tmem.cc b/tilelang/original/src/transform/lower_shared_tmem.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a3ad187e9ec5fa5da19f06e30129e1f2f18b49a --- /dev/null +++ b/tilelang/original/src/transform/lower_shared_tmem.cc @@ -0,0 +1,321 @@ +/*! + * \file lower_shared_tmem.cc + * \brief Convert shared.tmem buffers to plain shared + ptx init, and do + * coordinate translation (from logical address to physical address) + */ +#include "../op/builtin.h" +#include "../target/utils.h" +#include "tvm/ir/type.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/stmt.h" +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class SharedTmemRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt body) { + SharedTmemRewriter rewriter; + return rewriter(body); + } + +private: + Stmt VisitStmt_(const BlockNode *op) final { + Block block = tvm::ffi::GetRef(op); + Array alloc_buffers = op->alloc_buffers; + if (op->annotations.count(attr::kLayoutMap)) { + auto layout_map = op->annotations.Get(attr::kLayoutMap); + ICHECK(layout_map) << "layout map is not defined"; + layout_map_ = layout_map->as>().value(); + } + + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + + Array tmem_buffers; + + for (const auto &[data, buffer] : buffer_map_) { + const auto *ptr_type = + buffer->data->type_annotation.as(); + auto storage_scope = ptr_type->storage_scope; + ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType"; + if (storage_scope == "shared.tmem") { + tmem_buffers.push_back(buffer); + } + } + + if (tmem_buffers.empty()) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(thread_var_.defined()) << "thread_var_ is not defined"; + + for (auto buffer : tmem_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + /* + Transform the tmem buffers to new allocations + transform: + tmem_buf0 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + tmem_buf1 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + + into: + tmem_buf0 = T.alloc_buffer((1,), "uint64", 
scope="shared.tmem_addr") + tmem_buf1 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr") + + if tx == 0: + T.ptx_init_tensor_memory(tmem_buf0[0], 128) + T.ptx_init_tensor_memory(tmem_buf1[0], 128) + */ + // 1. create new data vars + Array new_data_vars; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + if (var_remap_.count(data)) + continue; + auto new_data = + Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared")); + var_remap_.Set(data, new_data); + new_data_vars.push_back(new_data); + } + + // 2. create new buffers + Array new_buffers; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + ICHECK(var_remap_.find(data) != var_remap_.end()) + << "data not found in var_remap_"; + auto new_data = var_remap_.at(data); + auto new_buffer = Buffer(new_data, tmem_dtype_, Array({1}), + Array({1}), PrimExpr(0), buffer->name, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); + new_buffers.push_back(new_buffer); + buffer_remap_.Set(buffer, new_buffer); + buffer_data_to_buffer_.Set(new_data, new_buffer); + } + + // remove the tmem buffers + alloc_buffers.MutateByApply([this](Buffer buf) { + if (buffer_remap_.find(buf) != buffer_remap_.end()) { + return buffer_remap_.at(buf); + } + return buf; + }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } else { + return StmtExprMutator::VisitStmt_(op); + } + + // 3. create init & dealloc calls for new buffers + std::vector init_mtmem_calls_; + std::vector dealloc_tmem_calls_; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + auto old_buffer = buffer_data_to_buffer_.at(data); + auto new_buffer = buffer_remap_.at(old_buffer); + + // Tmem physical coord range analysis + ICHECK(old_buffer->shape.size() == 2); + + auto analyzer = std::make_shared(); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(old_buffer->shape[1]); + int num_cols_required = phy_col_bounds->max_value; + ICHECK(num_cols_required <= 512) + << "The number of columns required for tmem buffer " + << old_buffer->name << " is " << num_cols_required + << ", which exceeds the maximum of 512 columns"; + + int num_cols_allocated = 32; // Align num_cols_allocated to power of 2 + for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2) + ; + + auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1, + PrimExpr(0), PrimExpr(1)); + auto alloc_call = Call(DataType::Handle(), tl::ptx_init_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + init_mtmem_calls_.push_back(Evaluate(alloc_call)); + auto dealloc_call = + Call(DataType::Handle(), tl::ptx_deallocate_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + dealloc_tmem_calls_.push_back(Evaluate(dealloc_call)); + } + auto compare_by_buffer_name = [&](const Stmt &a, const Stmt &b) { + auto call_a = a.as()->value.as(); + auto call_b = b.as()->value.as(); + auto num_cols_a = call_a->args[1].as()->value; + auto num_cols_b = call_b->args[1].as()->value; + return num_cols_a > num_cols_b; + }; + std::sort(init_mtmem_calls_.begin(), init_mtmem_calls_.end(), + compare_by_buffer_name); + + Array new_body; + auto target = Target::Current(); + auto warp_size = TargetGetWarpSize(target); + auto thread_var_div_warp_size = + FloorDiv(thread_var_->var, IntImm(thread_var_->var->dtype, warp_size)); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + init_mtmem_calls_.size() > 1 + ? 
SeqStmt(init_mtmem_calls_) + : init_mtmem_calls_.back(), + Stmt())); + new_body.push_back( + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")}))); + new_body.push_back(block->body); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + dealloc_tmem_calls_.size() > 1 + ? SeqStmt(dealloc_tmem_calls_) + : dealloc_tmem_calls_.back(), + Stmt())); + + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.erase(attr::kLayoutMap); + block_ptr->body = SeqStmt(new_body); + + return StmtExprMutator::VisitStmt_(block.get()); + } + + PrimExpr GetTmemOffset(const Buffer &buffer, const Array &indices) { + ICHECK(buffer->shape.size() == 2); + ICHECK(indices.size() == 2); + ICHECK(layout_map_.defined()); + ICHECK(layout_map_.count(buffer)) + << "The layout of tmem buffer " << buffer->name + << " is not defined in the layout map"; + auto layout = layout_map_[buffer]; + ICHECK(layout.defined()); + Array tmem_phy_coords = layout->Forward(indices); + PrimExpr result = + tmem_phy_coords[0] << 16 | + tmem_phy_coords + [1]; // https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-memory-addressing + return result; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + // Translate tmem[logical_row, logical_col] to tmem[0] + tmem_offset + // Where + // - (logical_row, logical_col) is the logical address in the tmem buffer + // - tmem[0] is the base address allocated for the tmem buffer + // - tmem_offset = tmem_phy_coords[0]<<16 | tmem_phy_coords[1] + // where tmem_phy_coords = layout.Forward(logical_row, logical_col) + // is the physical address in the tmem buffer + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto buffer = load->buffer; + auto indices = load->indices; + + if (buffer_remap_.count(buffer)) { + auto new_buffer = buffer_remap_[load->buffer]; + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } else if (var_remap_.count(buffer->data)) { + auto new_buffer = Buffer( + var_remap_[buffer->data], tmem_dtype_, buffer->shape, buffer->strides, + buffer->elem_offset, buffer->name, buffer->data_alignment, + buffer->offset_factor, buffer->buffer_type); + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto buffer = store->buffer; + ICHECK(buffer.scope() != "shared.tmem") + << "We should never directly store data into tmem!"; + return store; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + Var buffer_data = Downcast(op->args[1]); + if (!var_remap_.count(buffer_data)) { + return StmtExprMutator::VisitExpr_(op); + } + Var new_data = var_remap_[buffer_data]; + return Call( + op->dtype, op->op, + {op->args[0], new_data, op->args[2], op->args[3], op->args[4]}); + } + auto expr = StmtExprMutator::VisitExpr_(op); + return expr; + } + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = tvm::ffi::GetRef(op); + if (var_remap_.count(var)) { + return var_remap_[var]; + } + return var; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return StmtExprMutator::VisitStmt_(op); + } + + // Datatypes for tmem + const DataType tmem_dtype_ = DataType::UInt(32); + // This is a 
workaround for cpu backend, + // we need to define a thread_var for the serial loop. + IterVar thread_var_; + Map var_remap_; + Map buffer_data_to_buffer_; + Map buffer_remap_; + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + Map layout_map_; +}; + +PrimFunc LowerSharedTmem(PrimFunc f) { + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "LowerSharedTmem: Require the target attribute"; + SharedTmemRewriter rewriter; + f.CopyOnWrite()->body = rewriter.Rewrite(f->body); + return f; +} + +namespace transform { +using namespace tir::transform; + +tvm::transform::Pass LowerSharedTmem() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return tl::LowerSharedTmem(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_thread_allreduce.cc b/tilelang/original/src/transform/lower_thread_allreduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc0fbeb851bf9377ed2c5753cdab328d56c7488b --- /dev/null +++ b/tilelang/original/src/transform/lower_thread_allreduce.cc @@ -0,0 +1,956 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Lower allreduce to device implementable ir. + * \file lower_thread_allreduce.cc + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" +#include "tir/transforms/update_pointer_storage_scope.h" + +namespace tvm { +namespace tl { +using namespace tir; +using namespace ffi; + +using runtime::StorageRank; +using runtime::StorageScope; + +/*! 
+ * \brief collect the mapping from the buffer var to its allocate + */ +class AllocateCollector : public StmtExprVisitor { + +private: + bool IsDynamicSharedMemory(Var buffer_var) { + StorageScope storage_scope = runtime::StorageScope::Create( + GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn"; + } + + bool IsStaticSharedMemory(Var buffer_var) { + StorageScope storage_scope = runtime::StorageScope::Create( + GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag.empty(); + } + +public: + void VisitStmt_(const AllocateNode *op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_[op->buffer_var.get()] = op; + } else if (IsStaticSharedMemory(op->buffer_var)) { + static_shmem_allocs_[op->buffer_var.get()] = op; + } + StmtExprVisitor::VisitStmt_(op); + } + // The dynamic mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The static mapping from the original buffer var to its allocate + std::unordered_map + static_shmem_allocs_; +}; + +class ThreadAllreduceBuilder final : public StmtExprMutator { +public: + explicit ThreadAllreduceBuilder(const TargetNode *target, + bool is_dynamic = false) + : target_(target), + warp_size_( + target->GetAttr("thread_warp_size", 1).value().IntValue()), + max_num_threads_(target->GetAttr("max_num_threads", -1) + .value() + .IntValue()) { + if (is_dynamic) { + shared_scope = "shared.dyn"; + } + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + thread_extents_.push_back(op); + Stmt ret = StmtExprMutator::VisitStmt_(op); + thread_extents_.pop_back(); + return ret; + } else if (op->attr_key == tir::attr::reduce_scope) { + const CommReducerNode *combiner = op->node.as(); + ICHECK(combiner); + reduce_combiner_.push_back(combiner); + Stmt ret = StmtExprMutator::VisitStmt_(op); + reduce_combiner_.pop_back(); + return ret; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + Stmt VisitStmt_(const EvaluateNode *op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + const CallNode *call = op->value.as(); + if (call && call->op.same_as(builtin::tvm_thread_allreduce())) { + return MakeAllreduce(call); + } else { + return stmt; + } + } + Stmt VisitStmt_(const AllocateNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + if (auto it = alloc_remap_.find(node->buffer_var.get()); + it != alloc_remap_.end()) { + Buffer buf = Downcast(it->second); + auto write_ptr = node.CopyOnWrite(); + write_ptr->buffer_var = buf->data; + write_ptr->dtype = buf->dtype; + write_ptr->extents = buf->shape; + write_ptr->condition = const_true(buf->dtype.lanes()); + + if (buf.scope() == shared_scope) { + // Use volatile access to shared buffer. 
+ write_ptr->body = + AttrStmt(buf->data, tir::attr::volatile_scope, 1, write_ptr->body); + } + } + return std::move(node); + } + + Optional GetRemappedBuffer(const Buffer &buf) { + if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) { + return it->second; + } + + if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) { + Buffer new_buf = buf; + new_buf.CopyOnWrite()->data = it->second; + buf_remap_[buf.get()] = new_buf; + return new_buf; + } + + return std::nullopt; + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (auto buf = GetRemappedBuffer(node->buffer)) { + node.CopyOnWrite()->buffer = buf.value(); + } + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + if (auto it = load_remap_.find(op->buffer->data.get()); + it != load_remap_.end()) { + for (const auto &index : op->indices) { + ICHECK(is_zero(index)) + << "The index of buffer " << op->buffer << " is " << index; + } + return it->second; + } + + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (auto opt = GetRemappedBuffer(load->buffer)) { + load.CopyOnWrite()->buffer = opt.value(); + } + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + if (auto opt = GetRemappedBuffer(store->buffer)) { + store.CopyOnWrite()->buffer = opt.value(); + } + return std::move(store); + } + +private: + // Thread entry + struct ThreadEntry { + runtime::ThreadScope scope; + IterVar iv; + int extent{}; + // comparator + bool operator<(const ThreadEntry &other) const { + return scope.dim_index < other.scope.dim_index; + } + }; + + // make allreduce. + Stmt MakeAllreduce(const CallNode *call) { + ICHECK(!reduce_combiner_.empty()); + const CommReducerNode *combiner = reduce_combiner_.back(); + size_t size = combiner->result.size(); + + const IntImmNode *size_of_args = call->args[0].as(); + ICHECK(size_of_args) << call->args[0]->GetTypeKey(); + ICHECK_EQ(size, size_of_args->value); + Array inits = combiner->identity_element; + std::vector values(size); + std::vector types(size); + PrimExpr cond = call->args[size + 1]; + for (size_t idx = 0; idx < size; ++idx) { + values[idx] = call->args[1 + idx]; + if (!is_one(cond)) { + values[idx] = Select(cond, values[idx], inits[idx]); + } + types[idx] = values[idx].dtype(); + } + std::vector buffers(size); + for (size_t idx = 0; idx < size; ++idx) { + PrimExpr arg = call->args[2 + size + idx]; + // Loads from boolean buffers may have cast nodes inserted by + // earlier passes. + if (auto cast = arg.as()) { + arg = cast->value; + } + buffers[idx] = Downcast(arg)->buffer; + } + + std::unordered_set reduce_set; + for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { + const VarNode *v = call->args[i].as(); + // The simply optimization replace a iteration variable with a constant + // when extent of the iteration is 1. As threaded IterVar always started + // from 0, we can just ignore this variable in this case. 
+ if (v) { + reduce_set.insert(v); + } else { + ICHECK(call->args[i].as() && + call->args[i].as()->value == 0) + << "arg" << i << "should be a VarNode or IntImmNode " + << "while it is " << call->args[i]; + } + } + + size_t nmatch = 0; + std::vector vred, vpar; + int reduce_dim_index = -1; + for (const AttrStmtNode *attr : thread_extents_) { + ThreadEntry e; + IterVar iv = Downcast(attr->node); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); + e.iv = iv; + ICHECK_LE(e.scope.rank, 1); + ICHECK_GE(e.scope.dim_index, 0) + << "vthread do not work with cross thread reduction"; + if (e.scope.rank == 1) { + const auto *ptr = attr->value.as(); + ICHECK(ptr) << "Need constant extent for reduce set " << iv; + e.extent = static_cast(ptr->value); + // ignore variables equal to 0 + if (e.extent == 1) { + continue; + } + + if (reduce_set.count(iv->var.get())) { + bool already_exists = false; + for (const auto &entry : vred) { + if (entry.scope.dim_index == e.scope.dim_index) { + already_exists = true; + break; + } + } + if (!already_exists) { + vred.push_back(e); + ++nmatch; + reduce_dim_index = e.scope.dim_index; + } + } else { + bool already_exists = false; + for (const auto &entry : vpar) { + if (entry.scope.dim_index == e.scope.dim_index) { + already_exists = true; + break; + } + } + if (!already_exists) { + vpar.push_back(e); + } + } + } + } + + // remove reduce thread from parallel thread + if (reduce_dim_index != -1) { + for (size_t i = 0; i < vpar.size(); ++i) { + if (vpar[i].scope.dim_index == reduce_dim_index) { + vpar.erase(vpar.begin() + i); + break; + } + } + } + + ICHECK_EQ(nmatch, reduce_set.size()) + << "Not all reduce index are presented in the context"; + std::sort(vred.begin(), vred.end()); + std::sort(vpar.begin(), vpar.end()); + // the size of each index. + int reduce_extent, group_extent; + PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); + PrimExpr group_index = FlattenThread(vpar, &group_extent); + + // the longest contiguous reduce extent after flattening + int contiguous_reduce_extent = 1; + std::vector> + block_threads; // tuple(dim_index, extent, is_reduce) + for (const ThreadEntry &thr : vred) { + if (thr.scope.rank == 1) { // threadIdx + block_threads.emplace_back(thr.scope.dim_index, thr.extent, true); + } + } + for (const ThreadEntry &thr : vpar) { + if (thr.scope.rank == 1) { // threadIdx + block_threads.emplace_back(thr.scope.dim_index, thr.extent, false); + } + } + // sort according to dim_index + std::sort(block_threads.begin(), block_threads.end()); + for (auto &&thr_attr : block_threads) { + auto [dim_index, extent, is_reduce] = thr_attr; + (void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 + if (is_reduce) { + contiguous_reduce_extent *= extent; + } else { + break; + } + } + + std::vector seq; + std::vector new_alloc_bufs; + // + // This is an optimization. For small reduction sizes, it may be beneficial + // for a single warp to performance the entire reduction. No trips to shared + // memory and no cross warp synchronizations are required. + // The following code emits the reduction as follows: + // + // Allocate reduction vars v[i], i = 0..size-1 + // + // for offset from WARP_SIZE to 1 by 2 + // + // a <- load(v[i]) + // b <- shuffle_down(load(v[i], offset)) + // v[i] <- reduction(a, b) + // + // broadcast results from lane 0 to all other lanes and store + // the final reduction result to the proper location. 
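+    //
+    // Worked example (illustrative): with warp_size_ == 32 and
+    // reduce_extent == 32, the shuffle offsets are 16, 8, 4, 2, 1; at each
+    // step lane i combines its value with the value shuffled down from
+    // lane i + offset, so after the last step lane 0 holds the reduction
+    // of all 32 lanes, which is then broadcast to the remaining lanes.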
+ // + // When the thread extent is multiple of warp size, we can use a two-stage + // warp-level reduction to optimize. This is implemented by applying the + // algorithm above twice. + // + // For example, suppose we want to use 512 threads to reduce 512 elements + // and the warp size is 32. In this case there are (512 / 32) = 16 warps. + // In the first stage, each of the 16 warps reduces 32 elements. So after + // the stage, we have 16 remaining elements to be reduced, one for each + // warp. We store the 16 elements in shared memory, and start the second + // stage. In the second stage we use the first 16 lanes of the first warp to + // reduce the remaining elements, and this reduction can also be optimized + // by shuffle_down warp-level primitives. + PrimExpr zero_index = make_const(reduce_index->dtype, 0); + + if (IsWarpReduction(types, group_extent, reduce_extent, + contiguous_reduce_extent)) { + std::vector reduce_results; + DataType mask_dtype = DataType::UInt(32); + PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); + + if (reduce_extent <= warp_size_) { + std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( + values, types, combiner, reduce_index, reduce_extent, group_index, + mask, std::nullopt, &seq); + + // Broadcast the reduction result from lane 0 to all other lanes. + // This avoids to emit predicated stores, as all threads are + // uniformly writing the same result. + for (size_t i = 0; i < size; ++i) { + Buffer buf = Downcast(reduce_results[i])->buffer; + PrimExpr val = BufferLoad(buf, {zero_index}); + ICHECK_EQ(val->dtype, types[i]); + PrimExpr splat = + WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), + val, reduce_extent * group_index); + seq.push_back(BufferStore(buf, splat, {zero_index})); + } + } else { + int n_warps = reduce_extent / warp_size_; + std::vector local_bufs; + + // 1. Create the staging buffer in shared memory. + std::vector staging_shared_bufs; + staging_shared_bufs.reserve(size); + for (size_t i = 0; i < size; ++i) { + Buffer staging_shared_buf = decl_buffer( + /*shape=*/{make_const(reduce_index->dtype, + n_warps * group_extent)}, + /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", + /*storage_scope=*/shared_scope); + staging_shared_bufs.push_back(staging_shared_buf); + new_alloc_bufs.push_back(staging_shared_buf); + } + + // 2. First round of allreduce. + std::tie(reduce_results, local_bufs) = + MakeWarpAllreduce(values, types, combiner, reduce_index, warp_size_, + group_index, mask, std::nullopt, &seq); + new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), + local_bufs.end()); + + // 3. Write allreduce results to staging buffer. + std::vector write_staging_buf; + write_staging_buf.reserve(size); + for (size_t i = 0; i < size; ++i) { + new_alloc_bufs.push_back( + Downcast(reduce_results[i])->buffer); + write_staging_buf.push_back(BufferStore( + /*buffer=*/staging_shared_bufs[i], + /*value=*/reduce_results[i], + /*indices=*/ + {group_index * n_warps + floordiv(reduce_index, warp_size_)})); + } + PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index; + seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf))); + seq.push_back(SyncThread(shared_scope)); + + // 4. Load staging buffer. + // Second round of allreduce. 
+ for (size_t i = 0; i < size; ++i) { + values[i] = + BufferLoad(/*buffer=*/staging_shared_bufs[i], + /*indices=*/{group_index * n_warps + reduce_index}); + } + std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( + values, types, combiner, reduce_index, n_warps, group_index, mask, + /*predicate=*/reduce_index < + make_const(reduce_index->dtype, n_warps), + &seq); + new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), + local_bufs.end()); + + // 5. Create shared memory buffer(s) of `group_extent` elements, storing + // the allreduce results so each thread can access. + std::vector write_result; + write_result.reserve(size); + for (size_t i = 0; i < size; ++i) { + new_alloc_bufs.push_back( + Downcast(reduce_results[i])->buffer); + Buffer broadcast_shared_buf = decl_buffer( + /*shape=*/{make_const(reduce_index->dtype, group_extent)}, + /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", + /*storage_scope=*/shared_scope); + write_result.push_back(BufferStore(broadcast_shared_buf, + reduce_results[i], {group_index})); + // Update `reduce_results`, pointing to the value loaded from the + // shared memory buffer. + reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index}); + } + seq.push_back(IfThenElse(reduce_index == zero_index, + SeqStmt::Flatten(write_result))); + seq.push_back(SyncThread(shared_scope)); + } + + // Write back allreduce results and update existing allocations. + for (size_t i = 0; i < size; ++i) { + ICHECK(!load_remap_.count(buffers[i]->data.get())); + PrimExpr pred = const_true(types[i].lanes()); + Buffer buf = Downcast(reduce_results[i])->buffer; + ICHECK_EQ(reduce_results[i]->dtype, types[i]); + load_remap_[buffers[i]->data.get()] = reduce_results[i]; + + auto node = + Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0)); + alloc_remap_[buffers[i]->data.get()] = buf; + var_remap_[buffers[i]->data.get()] = buf->data; + buf_remap_[buffers[i].get()] = buf; + } + } else { + std::vector shared_bufs(size); + if (reduce_extent == 1) { + // special case, no reduction is needed. + std::vector stores; + stores.reserve(size); + for (size_t i = 0; i < size; ++i) { + stores.emplace_back(BufferStore(buffers[i], values[i], {0})); + } + return SeqStmt::Flatten(stores); + } + // This sync is necessary because there might be incomplete read of + // previous iteration on the same buffer. + seq.emplace_back(SyncThread(shared_scope)); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = decl_buffer( + {IntImm(group_index->dtype, group_extent * reduce_extent)}, + types[idx], "red_buf" + std::to_string(idx), shared_scope); + seq.emplace_back( + BufferStore(shared_bufs[idx], values[idx], + {BufIndex(reduce_index, group_index, reduce_extent)})); + } + seq.emplace_back(SyncThread(shared_scope)); + seq.emplace_back(MakeBufAllreduce( + combiner, types, shared_bufs, reduce_index, group_index, + reduce_extent, group_extent, contiguous_reduce_extent)); + for (size_t idx = 0; idx < size; ++idx) { + ICHECK(!load_remap_.count(buffers[idx]->data.get())); + PrimExpr pred = const_true(types[idx].lanes()); + BufferLoad load(shared_bufs[idx], + {BufIndex(make_zero(reduce_index.dtype()), group_index, + reduce_extent)}); + ICHECK_EQ(load->dtype, types[idx]); + load_remap_[buffers[idx]->data.get()] = load; + alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx]; + var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; + buf_remap_[buffers[idx].get()] = shared_bufs[idx]; + } + } + + // Fix all local allocations as all statements are built. 
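+    // Resulting nesting (sketch, buffer names illustrative): for
+    // new_alloc_bufs = {red_buf0, t0} the loop below produces
+    //   Allocate(t0, DeclBuffer(t0,
+    //     Allocate(red_buf0, DeclBuffer(red_buf0, SeqStmt(seq)))))
+    // i.e. each temporary gets a storage allocation and a buffer
+    // declaration wrapped around the flattened reduction body.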
+ Stmt body = SeqStmt::Flatten(seq); + for (const Buffer &buf : new_alloc_bufs) { + body = DeclBuffer(buf, body); + body = Allocate(buf->data, buf->dtype, buf->shape, + const_true(buf->dtype.lanes()), body); + } + + return body; + } + + std::pair, std::vector> + MakeWarpAllreduce(std::vector src_values, // + std::vector dtypes, // + const CommReducerNode *combiner, // + const PrimExpr &reduce_index, int reduce_extent, // + const PrimExpr &group_index, // + const PrimExpr &mask, + const Optional &predicate, // + std::vector *seq) { + int n_buffers = src_values.size(); + + std::vector shared_bufs; + std::vector local_bufs; + shared_bufs.reserve(n_buffers); + + // This is the index to the reduction variable, one reduction + // variable per warp. Local scope seems easier to reason without + // relying on a pattern match pass to fix it later. + Array zero_indices = {0}; + Array shape = {1}; + + std::vector load_values; + load_values.reserve(n_buffers); + for (int idx = 0; idx < n_buffers; ++idx) { + shared_bufs.push_back(decl_buffer( + shape, dtypes[idx], "red_buf" + std::to_string(idx), "local")); + load_values.push_back( + BufferStore(shared_bufs[idx], src_values[idx], zero_indices)); + + // Uses a local variable to store the shuffled data. Later + // on, an allocation will be built for this local variable. + local_bufs.push_back( + decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local")); + } + + if (predicate.defined()) { + seq->push_back( + IfThenElse(predicate.value(), SeqStmt::Flatten(load_values))); + } else { + seq->insert(seq->end(), load_values.begin(), load_values.end()); + } + + // The mask for this reducer, as this reducer may sit inside + // a divergent control flow. Here it uses a variable to cache the current + // active channels. + Optional mask_buffer; + if (need_warp_shuffle_mask_) { + mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local"); + seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices)); + // Push the buffer description. Later this will have an + // allocation built for it. + local_bufs.push_back(mask_buffer.value()); + } + + // Emit reductions within a warp. + int start_offset = 1; + while (start_offset * 2 < reduce_extent) { + start_offset *= 2; + } + for (int offset = start_offset; offset > 0; offset /= 2) { + // Load reduction values, no synchronization needed. + Array a, b; + for (int i = 0; i < n_buffers; ++i) { + const Buffer &shared_buf = shared_bufs[i]; + BufferLoad val(shared_buf, zero_indices); + ICHECK_EQ(val->dtype, dtypes[i]); + a.push_back(val); + + // __shfl_*sync calls shall not appear in if_then_else expressions + // as this is causing extra divergency. E.g. + // + // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); + // + // behaves differently from + // + // int t = __shfl_sync(mask, v1, 0); + // v1 = (v2 < v3) ? v3 : t; + // + // The former may cause dead lock as there is a divergent + // branch with a warp sync call inside. + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), + mask_buffer, val, offset); + const Buffer &local_buf = local_bufs[i]; + Stmt s = BufferStore(local_buf, other, zero_indices); + seq->push_back(s); + + BufferLoad load = BufferLoad(local_buf, zero_indices); + ICHECK_EQ(load->dtype, dtypes[i]); + b.push_back(load); + } + + // Do reductions. + Array ret = (*combiner)(a, b); + + // Store the reduction result to itself. 
+ std::vector stores; + stores.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + const Buffer &buf = shared_bufs[i]; + stores.push_back(BufferStore(buf, ret[i], zero_indices)); + } + + // During the sub-warp reduction, values from inactive threads could be + // read, which is an undefined behavior according to the cuda document. + // + // In practice, the return value are usually 0, which does no harm to sum + // reduction. However, the result can be incorrect in max or prod + // reduction. Therefore an additional range check has to be performed to + // ensure the correctness. + if (offset * 2 > reduce_extent) { + PrimExpr cond = reduce_index + offset < reduce_extent; + seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores))); + } else { + seq->push_back(SeqStmt::Flatten(stores)); + } + } + + std::vector reduce_results; + reduce_results.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices)); + } + + return {reduce_results, local_bufs}; + } + + // make allreduce. + Stmt MakeBufAllreduce(const CommReducerNode *combiner, + const std::vector &types, + const Array &shared_bufs, PrimExpr reduce_index, + PrimExpr group_index, int reduce_extent, + int group_extent, int contiguous_reduce_extent) { + // Get next power of two + int reduce_align = 1; + while (reduce_extent > reduce_align) { + reduce_align = reduce_align << 1; + } + ICHECK_GT(reduce_align, 1); + std::vector seq; + + size_t size = shared_bufs.size(); + PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent); + // make reduction + auto fload = [&](int offset) { + Array a, b; + for (size_t i = 0; i < size; ++i) { + BufferLoad b_load( + shared_bufs[i], + {BufIndex(reduce_index + offset, group_index, reduce_extent)}); + ICHECK_EQ(b_load->dtype, types[i]); + b.push_back(b_load); + + BufferLoad a_load(shared_bufs[i], {buf_index}); + ICHECK_EQ(a_load->dtype, types[i]); + a.push_back(a_load); + } + Array ret = (*combiner)(a, b); + return ret; + }; + auto fstore = [&](const Array &ret) { + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index}); + } + return SeqStmt::Flatten(stores); + }; + auto freduce = [&](int offset) { + auto ret = fload(offset); + return fstore(ret); + }; + // Step one, check for + if (reduce_align > reduce_extent) { + // reduction with the boundary condition + reduce_align = reduce_align >> 1; + PrimExpr cond = reduce_index < (reduce_extent - reduce_align); + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); + seq.emplace_back(SyncThread(shared_scope)); + } + + // normal synchronization + bool warp_align = + group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0; + while (reduce_align > contiguous_reduce_extent || + reduce_align > warp_size_ || !warp_align) { + if (reduce_align == 1) { + break; + } + reduce_align = reduce_align >> 1; + PrimExpr cond = reduce_index < reduce_align; + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); + seq.emplace_back(SyncThread(shared_scope)); + } + // in warp synchronization. + if (reduce_align > 1) { + PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1); + + std::vector in_warp_seq; + + while (reduce_align > 1) { + reduce_align = reduce_align >> 1; + + // freduce can read/write to the same memory location. For + // example, with reduce_align of 4, threadIdx 3 reads from + // memory location 7 as threadIdx 7 is writing to it. 
+ // Therefore, we need to separate out the load from the store + // with a memory barrier in-between. This isn't necessary for + // the earlier normal synchronization, because those are each + // protected by an if-statement. The if-statement is avoided + // here to reduce thread divergence. + auto loads = fload(reduce_align); + + Array in_warp_local_vars; + for (auto expr : loads) { + Var var("w_" + std::to_string(reduce_align) + "_" + + std::to_string(in_warp_local_vars.size()), + expr->dtype); + in_warp_local_vars.push_back(var); + } + + std::vector in_let_statement; + in_let_statement.emplace_back(SyncThread("warp")); + in_let_statement.emplace_back( + fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()})); + in_let_statement.emplace_back(SyncThread("warp")); + + Stmt body = SeqStmt::Flatten(in_let_statement); + for (size_t i = 0; i < size; i++) { + body = LetStmt(in_warp_local_vars[i], loads[i], body); + } + in_warp_seq.push_back(body); + } + + Stmt warp_body = SeqStmt::Flatten(in_warp_seq); + + seq.emplace_back(IfThenElse(in_warp_cond, warp_body)); + seq.emplace_back(SyncThread(shared_scope)); + } + return SeqStmt::Flatten(seq); + } + // Flatten the thread index. + // Also return a warp number, + PrimExpr FlattenThread(const std::vector &tvec, + int *out_total_extent) { + int &total_extent = *out_total_extent; + total_extent = 1; + if (tvec.empty()) { + return make_zero(DataType::Int(32)); + } + + PrimExpr ret; + for (const ThreadEntry &e : tvec) { + if (ret.defined()) { + ret = ret + e.iv->var * total_extent; + } else { + ICHECK_EQ(total_extent, 1); + ret = e.iv->var; + } + total_extent *= e.extent; + } + return ret; + } + // The local buffer index. + PrimExpr BufIndex(PrimExpr reduce_index, const PrimExpr &group_index, + int reduce_extent) { + if (!is_zero(group_index)) { + return analyzer_.Simplify(group_index * reduce_extent + reduce_index); + } else { + return reduce_index; + } + } + // sync thread op. + static Stmt SyncThread(const std::string &sync) { + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync)})); + } + + // Emit warp shuffle calls. + PrimExpr WarpShuffle(const Op &op, const Optional &mask_buffer, + const PrimExpr &val, PrimExpr delta_or_lane) { + Array indices = {0}; + PrimExpr mask; + if (mask_buffer.defined()) { + mask = BufferLoad(mask_buffer.value(), indices); + } else { + mask = IntImm(DataType::Int(32), 0); + } + PrimExpr width = IntImm(DataType::Int(32), warp_size_); + Array args{mask, val, std::move(delta_or_lane), width, width}; + return Call(val.dtype(), op, args); + } + + // Check if we can use warp level reduction. + // + // Note: The ROCm backend will only have warp reductions for now. + // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal). 
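+  // Illustrative examples (assuming warp_size_ == 32 on a CUDA target): a
+  // contiguous reduce_extent of 8 is accepted as a sub-warp reduction because
+  // 32 % 8 == 0, while a contiguous reduce_extent of 128 is accepted as a
+  // multi-warp reduction because 128 % 32 == 0 and the thread count is at
+  // most warp_size_ * warp_size_.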
+ bool IsWarpReduction(const std::vector &types, int group_extent, + int reduce_extent, int contiguous_reduce_extent) { + if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") && + (target_->kind->name != "metal")) { + return false; + } + + need_warp_shuffle_mask_ = target_->kind->name != "metal"; + + // rocm only supports 32 bit operands for shuffling at the moment + if ((target_->kind->name == "rocm") && + (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_fixed_length_vector()) + return ty.bits() * ty.lanes() != 32; + return ty.bits() != 32; + }))) { + return false; + } + + // Supported types: + // {u}int, {u}long, {u}long long, float, double, half/half2 + if (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_float16()) + return ty.lanes() > 2; + if (ty.is_fixed_length_vector()) + return true; + return ty.bytes() < 4 || ty.bytes() > 8; + })) { + return false; + } + if (thread_extents_.empty()) { + return false; + } + + // reduce region must be contiguous. + if (contiguous_reduce_extent != reduce_extent) { + return false; + } + + // whether reduce_extent and group_extent are valid for warp reduction. + if (target_->kind->name == "rocm") { + return reduce_extent == warp_size_; + } else { + if (reduce_extent == 1) { + return false; // no need to warp reduce + } else { + bool is_subwarp_reduction = warp_size_ % reduce_extent == 0; + bool is_multiwarp_reduction = + max_num_threads_ != -1 && + max_num_threads_ <= warp_size_ * warp_size_ && + reduce_extent % warp_size_ == 0; + if (is_subwarp_reduction || is_multiwarp_reduction) { + return true; + } else { + return group_extent == 1 && reduce_extent <= warp_size_; + } + } + } + } + + // The target. + const TargetNode *target_ = nullptr; + // The shared scope. + String shared_scope = "shared"; + // The warp size of the device. + int warp_size_{1}; + // The maximum number of threads of the device. "-1" denotes unknown. + int max_num_threads_{-1}; + // A boolean indicating if the target supports warp-level masking. + bool need_warp_shuffle_mask_{}; + + // surrounding scope of thread extent. 
+ std::vector thread_extents_; + std::vector reduce_combiner_; + // The load remap + std::unordered_map load_remap_; + // Allocate remap + std::unordered_map alloc_remap_; + // BufferVar remap + std::unordered_map var_remap_; + // Buffer remap + std::unordered_map buf_remap_; + // Internal analyzer + arith::Analyzer analyzer_; +}; + +namespace transform { +using namespace tir::transform; + +tvm::transform::Pass LowerThreadAllreduce() { + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + AllocateCollector collector; + collector(f->body); + bool is_dynamic = collector.dyn_shmem_allocs_.size() > 1; + + auto *n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) + << "LowerThreadAllreduce: Require the target attribute"; + const TargetNode *target_node = target.as(); + ThreadAllreduceBuilder thread_all_reduce(target_node, is_dynamic); + n->body = thread_all_reduce(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerThreadAllreduce", + LowerThreadAllreduce); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/lower_tile_op.cc b/tilelang/original/src/transform/lower_tile_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c88e05c56fc0631c760964d0a31bbafc6cc6bdad --- /dev/null +++ b/tilelang/original/src/transform/lower_tile_op.cc @@ -0,0 +1,708 @@ +/*! + * \file lower_tile_op.cc + * \brief Lower the tile op for further codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../layout/layout.h" +#include "../layout/utils.h" +#include "../op/builtin.h" +#include "../op/gemm.h" +#include "../op/gemm_sp.h" +#include "../op/operator.h" + +#include "arith/ir_mutator_with_analyzer.h" +#include "loop_partition.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout, + Map &var_remap) { + const auto *ptr_type = + TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); + Type new_type; + // convert fragments to normal local buffer + if (ptr_type->storage_scope == "local.fragment") { + new_type = PointerType(ptr_type->element_type, "local"); + } else { + new_type = buffer->data->type_annotation; + } + Var new_var; + if (ptr_type->storage_scope == "global") { + new_var = buffer->data; + } else { + if (var_remap.count(buffer->data)) { + new_var = var_remap[buffer->data]; + } else { + new_var = Var(buffer->data->name_hint, new_type); + var_remap.Set(buffer->data, new_var); + } + } + Array layout_shape = layout->OutputShape(); + Array output_shape = layout_shape; + + if (ptr_type->storage_scope == "shared" || + ptr_type->storage_scope == "shared.dyn") { + int replicate_extent = 1; + Array buffer_shape = buffer->shape; + int buffer_extent = 1; + int layout_extent = 1; + for (size_t i = 0; i < buffer_shape.size(); i++) { + auto shape = buffer_shape[i].as(); + buffer_extent *= shape->value; + } + for (size_t i = 0; i < layout_shape.size(); i++) { + auto shape = layout_shape[i].as(); + layout_extent *= shape->value; + } + replicate_extent = buffer_extent / layout_extent; + if (replicate_extent > 1) { + output_shape.insert(output_shape.begin(), replicate_extent); + } + } + return Buffer(new_var, buffer->dtype, output_shape, {}, buffer->elem_offset, + 
buffer->name, buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); +} + +// The function `makeBufferWithLayout` creates a new Buffer object based on the +// given buffer and layout. It handles remapping of buffer variables, adjusts +// the storage scope if needed (e.g., from "local.fragment" to "local"), and +// computes the output shape according to the layout. For shared memory buffers, +// it also handles replication if the buffer's extent is larger than the +// layout's extent. +class LayoutRemapRewriter : public arith::IRMutatorWithAnalyzer { +public: + static Stmt Substitute(Stmt stmt, Map layout_remap) { + arith::Analyzer analyzer; + LayoutRemapRewriter substituter(&analyzer); + substituter.layout_remap_ = std::move(layout_remap); + return substituter.VisitStmt(stmt); + } + +private: + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + Stmt VisitStmt_(const BlockNode *op) final { + auto block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + if (op->annotations.count(attr::kLayoutMap)) { + block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_); + } + return block; + } + + Map layout_remap_; +}; + +/*! + * \brief A class that rewrites buffer references in a statement based on a + * given buffer remapping. + * + * This class is used to update buffer references in a statement after buffer + * transformations have been applied. It specifically handles the remapping of + * padding annotations. + */ +class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer { +public: + /*! + * \brief Substitute buffer references in a statement based on a given buffer + * remapping. \param stmt The statement to rewrite. \param buffer_remap A map + * from old buffers to new buffers. \return The rewritten statement. + */ + static Stmt Substitute(const Stmt &stmt, Map buffer_remap) { + arith::Analyzer analyzer; + RemapBufferRewriter substituter(&analyzer); + substituter.buffer_remap_ = std::move(buffer_remap); + return substituter.VisitStmt(stmt); + } + +private: + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + Stmt VisitStmt_(const BlockNode *op) final { + if (op->annotations.count(attr::kSafeValueMap)) { + return RewritePaddingMap(op); + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + /*! + * \brief Rewrite the padding map annotation of a block. + * \param op The block node to rewrite. + * \return The rewritten block. + */ + Stmt RewritePaddingMap(const BlockNode *op) { + auto safe_value_map = op->annotations.Get(attr::kSafeValueMap); + if (!safe_value_map) { + LOG(FATAL) << "Padding map annotation is missing"; + } + + Map var_remap = CreateVarRemap(); + Map new_safe_value_map = RemapPaddingMap( + Downcast>(safe_value_map.value()), var_remap); + + auto block = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.Set(attr::kSafeValueMap, new_safe_value_map); + return block; + } + + /*! + * \brief Create a mapping from old variables to new variables based on buffer + * remapping. \return A map from old variables to new variables. + */ + Map CreateVarRemap() const { + Map var_remap; + for (const auto &[buffer, buffer_remap] : buffer_remap_) { + var_remap.Set(buffer->data, buffer_remap->data); + } + return var_remap; + } + + /*! + * \brief Remap the padding map using the variable remapping. + * \param safe_value_map The original padding map. + * \param var_remap The variable remapping. + * \return The remapped padding map. 
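+   *
+   * For example (illustrative): if a buffer's data var was remapped to that
+   * of a transformed buffer, the safe-value entry keyed by the original data
+   * var is re-keyed to the remapped var; entries whose vars were not remapped
+   * are copied through unchanged.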
+ */ + Map RemapPaddingMap(const Map &safe_value_map, + const Map &var_remap) const { + Map new_safe_value_map; + for (const auto &[var, padding] : safe_value_map) { + if (var_remap.count(var)) { + new_safe_value_map.Set(var_remap.at(var), padding); + } else { + new_safe_value_map.Set(var, padding); + } + } + return new_safe_value_map; + } + + Map buffer_remap_; +}; + +class LowerTileOpPass : arith::IRMutatorWithAnalyzer { +public: + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + LowerTileOpPass substituter(&analyzer); + // Trace the buffer map for tvm_access_ptr + substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end()); + for (const auto &[_, buffer] : f->buffer_map) { + substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute"; + substituter.target_ = target.value(); + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = substituter.VisitStmt(f->body); + fptr->body = + RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_); + fptr->body = + LayoutRemapRewriter::Substitute(fptr->body, substituter.layout_remap_); + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + Optional opt_disable_tma_lower = + ctxt->GetConfig(kDisableTMALower, Optional()); + + if (!opt_disable_tma_lower.value_or(Bool(false))) { + // @lei: this is a workaround, as if we don't disable tma lower, + // cp async lowering won't be generated. + ctxt->config.Set(kDisableTMALower, Bool(!substituter.has_tma_)); + } + return f; + } + +private: + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + Stmt VisitStmt_(const BlockNode *op) final { + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : op->alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + Map vmap; + if (op->annotations.count(attr::kLayoutMap)) { + auto layout_map = op->annotations.at(attr::kLayoutMap) + .as>() + .value(); + for (auto [buffer, layout] : layout_map) { + buffer_remap_.Set(buffer, + makeBufferWithLayout(buffer, layout, var_remap_)); + layout_map_.Set(buffer, layout); + } + } + // Begin a new workspace collection frame for this block scope + workspace_stack_.emplace_back(); + + auto block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + auto block_ptr = block.CopyOnWrite(); + for (size_t i = 0; i < block->alloc_buffers.size(); i++) { + auto buffer = block->alloc_buffers[i]; + if (buffer_remap_.count(buffer)) { + block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]); + } + } + // Attach any workspaces requested within this block to its alloc_buffers + if (!workspace_stack_.empty()) { + for (const auto &buffer : workspace_stack_.back()) { + block_ptr->alloc_buffers.push_back(buffer); + } + workspace_stack_.pop_back(); + } + return block; + } + + int CheckAndGetBufferRowSize(const Buffer &buffer) { + CHECK(buffer->shape.size() >= 2) + << "The dimension of Buffer \"" << buffer->name << "\" with shape " + << buffer->shape << " should be at least 2"; + + auto dim = buffer->shape.size(); + auto buffer_row_size = buffer->shape[dim - 1].as()->value; + return buffer_row_size; + } + + struct AccessPtrResult { + PrimExpr expr; + bool rewritten{false}; + 
}; + + AccessPtrResult + HandleAccessPtrAndOffset(const PrimExpr &access_ptr, + const Optional &offset = std::nullopt, + DataType dtype = DataType::Int(32)) { + AccessPtrResult result{access_ptr, false}; + // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and + // accumulate it to smem_offset + CHECK(access_ptr->IsInstance()) + << "Invalid access ptr for permuted layout: " << access_ptr; + auto access_ptr_call = Downcast(access_ptr); + if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) { + LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet"; + } else if (access_ptr_call->op.same_as(builtin::address_of())) { + Optional resolved = ResolveBufferLoad(access_ptr_call->args[0]); + ICHECK(resolved.defined()) + << "Invalid access op for permuted layout: " << access_ptr; + PrimExpr load_expr = resolved.value(); + if (!load_expr.same_as(access_ptr_call->args[0])) { + auto node = access_ptr_call.CopyOnWrite(); + node->args.Set(0, load_expr); + access_ptr_call = Call(access_ptr_call->dtype, access_ptr_call->op, + {load_expr}, access_ptr_call->span); + } + BufferLoad load = Downcast(access_ptr_call->args[0]); + Array indices = load->indices; + Array old_shape = load->buffer->shape; + + CHECK_EQ(indices.size(), old_shape.size()) + << "Indices size and shape size must match for general N-dimensional " + "buffer " + << "but got indices size: " << indices.size() + << " and shape size: " << old_shape.size(); + + PrimExpr elem_offset = 0; + PrimExpr stride = 1; + + for (int i = static_cast(old_shape.size()) - 1; i >= 0; --i) { + elem_offset += indices[i] * stride; + stride *= old_shape[i]; + } + + PrimExpr smem_offset = + elem_offset + (offset.defined() ? offset.value() : 0); + + Buffer remap_key = FindRemapBuffer(load->buffer).value_or(load->buffer); + Optional layout = FindLayout(remap_key); + if (!layout.defined() || !buffer_map_.count(remap_key->data)) { + return result; + } + auto new_buffer = buffer_remap_.count(remap_key) + ? 
buffer_remap_[remap_key] + : load->buffer; + auto new_shape = new_buffer->shape; + + auto buffer_map_iter = buffer_map_.find(Downcast(remap_key->data)); + + int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second); + (void)buffer_row_size; + + // Convert offset to target-dimension, reindex it and convert it back + Array multi_dim_indices; + PrimExpr remaining_offset = smem_offset; + + for (int i = static_cast(old_shape.size()) - 1; i >= 0; --i) { + multi_dim_indices.insert(multi_dim_indices.begin(), + floormod(remaining_offset, old_shape[i])); + remaining_offset = floordiv(remaining_offset, old_shape[i]); + } + + auto forward_indices = layout.value()->Forward(multi_dim_indices); + PrimExpr new_offset = 0; + PrimExpr stride_offset = 1; + for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { + new_offset += forward_indices[i] * stride_offset; + stride_offset *= new_shape[i]; + } + new_offset = analyzer_->Simplify(new_offset); + + Array new_indices; + for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { + new_indices.insert(new_indices.begin(), + floormod(new_offset, new_shape[i])); + new_offset = floordiv(new_offset, new_shape[i]); + } + + Array new_args = {BufferLoad(new_buffer, new_indices)}; + if (buffer_remap_.count(remap_key)) { + layout_remap_.Set(new_buffer, layout.value()); + } + result.rewritten = true; + result.expr = Call(access_ptr_call->dtype, access_ptr_call->op, new_args, + access_ptr_call->span); + return result; + } else { + LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr; + } + + return result; + } + + Optional ResolveBufferLoad(const PrimExpr &expr) const { + if (expr->IsInstance()) { + return expr; + } + if (const auto *var_node = expr.as()) { + Var var = tvm::ffi::GetRef(var_node); + auto it = let_bindings_.find(var); + if (it != let_bindings_.end()) { + return it->second; + } + } + return Optional(); + } + + Optional FindRemapBuffer(const Buffer &buffer) const { + if (buffer_remap_.count(buffer)) { + return buffer; + } + auto it = buffer_map_.find(buffer->data); + if (it != buffer_map_.end() && buffer_remap_.count(it->second)) { + return it->second; + } + for (const auto &kv : buffer_remap_) { + if (kv.first->data.same_as(buffer->data)) { + return kv.first; + } + if (kv.first->name == buffer->name) { + return kv.first; + } + } + return Optional(); + } + + Optional FindLayout(const Buffer &buffer) const { + if (layout_map_.count(buffer)) { + return layout_map_[buffer]; + } + auto it = buffer_map_.find(buffer->data); + if (it != buffer_map_.end() && layout_map_.count(it->second)) { + return layout_map_[it->second]; + } + for (const auto &kv : layout_map_) { + if (kv.first->data.same_as(buffer->data)) { + return kv.second; + } + if (kv.first->name == buffer->name) { + return kv.second; + } + } + return Optional(); + } + + PrimExpr VisitExpr_(const tir::CallNode *op) final { + if ((!has_tma_) && (op->op.same_as(tl::tma_load()) || + op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_store()))) { + has_tma_ = true; + } + Array ptx_instructions = {builtin::ptx_ldmatrix(), + builtin::mma_store()}; + + if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) == + ptx_instructions.end()) { + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + return call; + } else { + is_ptx_ = true; + } + // Rewrite from/to shared or shared.dyn to/from local + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + if (call->op.same_as(builtin::ptx_ldmatrix())) { + // form: T.ptx_ldmatrix(..., 
smem_ptr, smem_offset) + // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask) + // or T.address_of(buffer, offset) + PrimExpr access_ptr = call->args[5]; + PrimExpr smem_offset = call->args[6]; + Call address_of_call = Downcast(access_ptr); + if (!address_of_call->op.same_as(builtin::address_of())) { + LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr; + } + Optional resolved = ResolveBufferLoad(address_of_call->args[0]); + ICHECK(resolved.defined()) + << "Invalid address_of argument for permuted layout: " + << address_of_call->args[0]; + PrimExpr load_expr = resolved.value(); + if (!load_expr.same_as(address_of_call->args[0])) { + auto call_node = call.CopyOnWrite(); + call_node->args.Set(5, Call(address_of_call->dtype, address_of_call->op, + {load_expr}, address_of_call->span)); + address_of_call = Downcast(call->args[5]); + access_ptr = call->args[5]; + } + BufferLoad load = Downcast(address_of_call->args[0]); + auto new_access_ptr = + HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype); + if (new_access_ptr.rewritten) { + auto new_call = call.CopyOnWrite(); + new_call->args.Set(5, new_access_ptr.expr); + new_call->args.Set(6, IntImm(smem_offset->dtype, 0)); + } + } else if (call->op.same_as(builtin::mma_store())) { + // because we will directly store result to Buffer instead of calling + // mma_store now + auto access_ptr = call->args[2]; + auto new_access_ptr = + HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype); + if (new_access_ptr.rewritten) { + auto new_call = call.CopyOnWrite(); + new_call->args.Set(2, new_access_ptr.expr); + } + } else { + LOG(FATAL) << "Invalid call node: " << call; + } + is_ptx_ = false; + return call; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + if (is_ptx_) { + return load; + } + auto buffer = load->buffer; + if (buffer_remap_.count(buffer)) { + auto new_indices = layout_map_[buffer]->Forward(load->indices); + auto new_buffer = buffer_remap_[load->buffer]; + layout_remap_.Set(new_buffer, layout_map_[load->buffer]); + return BufferLoad(new_buffer, new_indices); + } else if (var_remap_.count(buffer->data)) { + auto new_buffer = Buffer( + var_remap_[buffer->data], buffer->dtype, buffer->shape, + buffer->strides, buffer->elem_offset, buffer->name, + buffer->data_alignment, buffer->offset_factor, buffer->buffer_type); + return BufferLoad(new_buffer, load->indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + auto buffer = store->buffer; + if (buffer_remap_.count(buffer)) { + auto new_indices = layout_map_[buffer]->Forward(store->indices); + auto new_buffer = buffer_remap_[store->buffer]; + layout_remap_.Set(new_buffer, layout_map_[store->buffer]); + return BufferStore(new_buffer, store->value, new_indices); + } else if (var_remap_.count(buffer->data)) { + auto new_buffer = Buffer( + var_remap_[buffer->data], buffer->dtype, buffer->shape, + buffer->strides, buffer->elem_offset, buffer->name, + buffer->data_alignment, buffer->offset_factor, buffer->buffer_type); + return BufferStore(new_buffer, store->value, store->indices); + } + return store; + } + + PrimExpr VisitExpr_(const VarNode *op) final { + auto var = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + if (buffer_data_to_buffer_.count(var)) { + auto buffer = buffer_data_to_buffer_[var]; + if (buffer_remap_.count(buffer)) + return buffer_remap_[buffer]->data; + } + 
return var; + } + + Stmt VisitStmt_(const LetStmtNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + bool recorded = false; + if (value->IsInstance()) { + let_bindings_[op->var] = value; + recorded = true; + } + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + Stmt body = this->VisitStmt(op->body); + if (recorded) { + let_bindings_.erase(op->var); + } + if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + auto n = this->CopyOnWrite(op); + n->value = value; + n->body = body; + return Stmt(n); + } + } + + /** + * @brief Handle an Evaluate node, lowering a detected tile operator to TIR. + * + * This visit implementation detects whether the Evaluate node represents a + * tile operator invocation (via ParseOperator). If no tile operator is found + * or the call targets a global function, the node is delegated to the base + * visitor. + * + * When a tile operator is present, the method: + * - Builds a workspace-allocation callback that creates a dynamic shared + * buffer named "workspace" (storage scope "shared.dyn") and returns its write + * access pointer. + * - Determines thread bounds for lowering from the analyzer's constant-int + * information for thread_var_; if unavailable, a default range [0,1) is + * used. + * - Invokes tile_op->Lower(...) with LowerArgs containing target, thread + * bounds, thread variable, the workspace callback, layout and buffer remap + * maps, and the list of GEMM-involved buffer vars; the analyzer is passed + * through for use during lowering. + * + * The lowered statement returned by the operator is then visited by the base + * IRMutatorWithAnalyzer and that result is returned. + * + * @return Stmt The (possibly transformed) statement after lowering or base + * visitor processing. + */ + Stmt VisitStmt_(const EvaluateNode *op) final { + const CallNode *call = op->value.as(); + // Do not analysis the call node to the global function. + if (call && call->op.as()) + return Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + + auto tile_op = ParseOperator(tvm::ffi::GetRef(op)); + if (!tile_op.defined()) + return IRMutatorWithAnalyzer::VisitStmt_(op); + AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { + auto workspace = + decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn"); + // Record workspace under the innermost block scope so its lifetime + // covers the statements that requested it and does not sink into + // subsequently created inner blocks (e.g., GEMM macro blocks). 
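+      // The frame pushed when entering the enclosing Block (see the BlockNode
+      // visitor above) is popped once that block has been fully visited, at
+      // which point any workspace recorded here is appended to the block's
+      // alloc_buffers.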
+ if (!workspace_stack_.empty()) { + workspace_stack_.back().push_back(workspace); + } else { + // Fallback: create a temporary frame (should be rare) + workspace_stack_.emplace_back(Array{workspace}); + } + return workspace.access_ptr(2); // write + }; + + Range thread_bounds; + + if (analyzer_->const_int_bound.IsBound(thread_var_->var)) { + auto const_int_bound = analyzer_->const_int_bound(thread_var_); + auto min_value = const_int_bound->min_value; + auto max_value = const_int_bound->max_value; + auto extent = max_value + 1 - min_value; + thread_bounds = + Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value), + IntImm(thread_var_->var.dtype(), extent)); + } else { + thread_bounds = Range::FromMinExtent(0, 1); + } + + // Convert let_bindings_ to Map for LowerArgs + Map let_var_to_expr; + for (const auto &[var, expr] : let_bindings_) { + let_var_to_expr.Set(var, expr); + } + + auto lowered = tile_op->Lower( + LowerArgs{target_, thread_bounds, thread_var_->var, callback, + layout_map_, buffer_remap_, let_var_to_expr}, + analyzer_); + return IRMutatorWithAnalyzer::VisitStmt(lowered); + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + if (iv->thread_tag == "threadIdx.x") { + thread_var_ = iv; + ICHECK(iv->dom->extent.as()); + thread_block_size_ = iv->dom->extent.as()->value; + } + } + return arith::IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Target target_; + Map buffer_data_to_buffer_; + Map layout_map_; + Map layout_remap_; + Map buffer_remap_; + // This is a workaround for cpu backend, + // we need to define a thread_var for the serial loop. + IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), + IterVarType::kDataPar); + size_t thread_block_size_ = 0; + // Stack of per-Block workspace buffers gathered while visiting children + std::vector> workspace_stack_; + // For ptx Node, we need to remap the buffer and indices + // By access CallNode instead of BufferLoad Node. + bool is_ptx_{false}; + std::unordered_map + let_bindings_; + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + Map var_remap_; + bool has_tma_{false}; +}; + +namespace transform { + +using namespace tir::transform; + +tvm::transform::Pass LowerTileOp() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return LowerTileOpPass::Substitute(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp); +} +} // namespace transform + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/make_packed_api.cc b/tilelang/original/src/transform/make_packed_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9e8f76e6163709777bf930dfcf1921ba45b6829 --- /dev/null +++ b/tilelang/original/src/transform/make_packed_api.cc @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file make_packed_api.cc Lower PrimFunc to use the packed function API. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/builtin.h" +#include "arg_binder.h" +#include "merge_if_stmt.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { +using namespace tir; +using namespace ffi; + +namespace { +class ReturnRewriter : public StmtMutator { +public: + explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {} + + Stmt VisitStmt_(const ForNode *node) override { + if (node->kind == ForKind::kParallel) + in_parallel_ += 1; + Stmt ret = StmtMutator::VisitStmt_(node); + if (node->kind == ForKind::kParallel) + in_parallel_ -= 1; + return ret; + } + + Stmt VisitStmt_(const EvaluateNode *node) override { + Stmt ret = StmtMutator::VisitStmt_(node); + const EvaluateNode *eval = ret.as(); + ICHECK(eval); + if (const CallNode *call = eval->value.as()) { + if (call->op.same_as(builtin::ret())) { + ICHECK_EQ(in_parallel_, 0) + << "tir.ret cannot be used in parallel scope."; + ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; + ret = WriteToOut(call->args[0]); + } + } + return ret; + } + +private: + struct ConvertedInfo { + int type_index{-1}; + PrimExpr expr; + }; + + ConvertedInfo ConvertForFFI(const PrimExpr &val) { + ConvertedInfo info; + + // convert val's data type to FFI data type, return type code + DataType dtype = val.dtype(); + if (dtype.is_bool()) { + info.type_index = ffi::TypeIndex::kTVMFFIBool; + info.expr = Cast(DataType::Int(64), val); + + } else if (dtype.is_int() || dtype.is_uint()) { + info.type_index = ffi::TypeIndex::kTVMFFIInt; + info.expr = Cast(DataType::Int(64), val); + } else if (dtype.is_float()) { + info.type_index = ffi::TypeIndex::kTVMFFIFloat; + info.expr = Cast(DataType::Float(64), val); + } else if (dtype.is_void()) { + info.type_index = ffi::TypeIndex::kTVMFFINone; + info.expr = val; + } else { + LOG(FATAL) << "data type " << dtype << " not supported yet"; + } + + return info; + } + + Stmt WriteToOut(PrimExpr val) { + auto info = ConvertForFFI(val); + Stmt store_tindex = tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex), + IntImm(DataType::Int(32), info.type_index)})); + Stmt store_zero_padding = tir::Evaluate(tir::Call( + DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding), + IntImm(DataType::Int(32), 0)})); + Stmt store_val = tir::Evaluate(tir::Call( + DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue), + info.expr})); + Stmt ret_zero = Evaluate(tvm::ret(0)); + return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero}); + } + + Var ret_var_; + int in_parallel_{0}; +}; + +class SubroutineCallRewriter : public 
StmtExprMutator { +public: + static ffi::Optional + Apply(const ffi::Map &packed_func_methods, + Stmt stmt) { + SubroutineCallRewriter rewriter(packed_func_methods); + stmt = rewriter.VisitStmt(stmt); + if (rewriter.made_change_) { + return stmt; + } else { + return std::nullopt; + } + } + +private: + explicit SubroutineCallRewriter( + const ffi::Map &packed_func_methods) + : packed_func_methods(packed_func_methods) {} + + PrimExpr VisitExpr_(const CallNode *op) override { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (auto *gvar_ptr = node->op.as()) { + auto gvar = ffi::GetRef(gvar_ptr); + if (auto symbol = packed_func_methods.Get(gvar)) { + ffi::Array cpacked_args; + cpacked_args.push_back(tir::StringImm(symbol.value())); + for (auto arg : node->args) { + cpacked_args.push_back(arg); + } + + // push an empty handle to be compatible with current cpacked convention + cpacked_args.push_back(tir::make_zero(DataType::Handle())); + made_change_ = true; + return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), + cpacked_args); + } + } + + return node; + } + const ffi::Map &packed_func_methods; + bool made_change_{false}; +}; + +} // namespace + +inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { + return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); +} + +inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { + Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr}); + return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0)); +} + +/* \brief Return the global_symbol of the function, if it should be updated + * + * \param func The function to be inspected + * + * \returns The global_symbol to be used for the function at call + * sites, or std::nullopt if the function is to remain unchanged. + */ +Optional RequiresPackedAPI(const PrimFunc &func) { + // A function with an explicit calling convention has already been + // lowered, and should not be modified. + if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { + if (CallingConv(opt.value()->value) != CallingConv::kDefault) { + return std::nullopt; + } + } + + // Internal function calls do not need the PackedFunc API + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (!global_symbol) { + return std::nullopt; + } + + return global_symbol; +} + +PrimFunc MakePackedAPI(PrimFunc func) { + auto global_symbol = RequiresPackedAPI(func); + if (!global_symbol) { + return func; + } + std::string name_hint = global_symbol.value(); + + Target target = [&]() { + auto opt = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt) << "MakePackedAPI required the function to be annotated with " + "tvm::attr::kTarget (" + << tvm::attr::kTarget + << "), but the function only has attributes " << func->attrs; + return opt.value(); + }(); + int target_device_type = target->GetTargetDeviceType(); + + // A function without a host target has already been lowered. 
+ Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + + auto *func_ptr = func.CopyOnWrite(); + // set the global symbol to the packed function name + const Stmt nop = Evaluate(0); + int num_args = static_cast(func_ptr->params.size()); + + // Data field definitions + // The packed fields + Var v_self_handle("self_handle", DataType::Handle()); + Var v_packed_args("args", DataType::Handle()); + Var v_num_packed_args("num_args", DataType::Int(32)); + Var v_result("result", PointerType(PrimType(DataType::Void()))); + + // The device context + Var device_id("dev_id"); + Integer device_type(target_device_type); + // seq_init gives sequence of initialization + // seq_check gives sequence of later checks after init + std::vector seq_init, seq_check, arg_buffer_declarations; + std::unordered_map vmap; + ArgBinder binder(&vmap); + + // --------------------------- + // local function definitions + // load i-th argument as type t + auto f_load_arg_value = [&](DataType arg_type, int i) { + ffi::Array call_args{ + v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; + // load 64 bit version + DataType api_type = APIType(arg_type); + PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); + // cast to the target version. + if (api_type != arg_type) { + res = Cast(arg_type, res); + } + return res; + }; + + // Assert correct type codes for each argument. This must be done + // *before* any initialization steps produced by + // `binder.BindDLTensor()`. The validity of those initialization + // steps depends on the correct types being present, and must not + // occur before the type codes are actually checked. + seq_init.push_back( + MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string { + std::ostringstream error_message; + error_message << name_hint << ": num_args should be " << num_args; + return error_message.str(); + }())); + + if (num_args > 0) { + seq_init.push_back( + MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL")); + } + + // Need to delay binding of the buffers, in case some arguments also + // appear in the buffer. + std::vector> var_def; + std::vector> buffer_def; + + // First, collect a reverse map from Buffer->data var to parameter var so we + // can detect whether a buffer is actually used by the function body. In + // addition, collect variables that appear in the buffer's shape/stride so we + // can consider uses of those symbols as a use of the buffer itself. + std::unordered_map data_var2param; + std::unordered_map> + shape_var2params; + for (const auto &kv : func_ptr->buffer_map) { + const Var ¶m = kv.first; + const Buffer &buf = kv.second; + data_var2param[buf->data.get()] = param.get(); + auto record_shape_vars = [&](const PrimExpr &e) { + PostOrderVisit(e, [&](const ObjectRef &n) { + if (const auto *v = n.as()) { + shape_var2params[v].push_back(param.get()); + } + }); + }; + for (const PrimExpr &e : buf->shape) + record_shape_vars(e); + for (const PrimExpr &e : buf->strides) + record_shape_vars(e); + if (buf->elem_offset.defined()) + record_shape_vars(buf->elem_offset); + } + + // A visitor that records + // - which parameter buffers are used via their data var (load/store/direct), + // - which shape/stride/offset symbols are referenced in the body. + // Shape symbols are not immediately attributed to all carrier buffers here; + // a minimal carrier set is selected after visiting. 
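+  // Illustrative example: if DLTensor params A and B both carry the symbolic
+  // extent `n` in their shapes while the body only reads `n` (and never A or
+  // B through their data pointers), either A or B alone can act as the
+  // carrier from which `n` is bound at runtime.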
+ struct UsedBufferDetector : public StmtExprVisitor { + UsedBufferDetector( + const std::unordered_map &data2param, + const std::unordered_map> + &shape2params) + : data2param(data2param), shape2params(shape2params) {} + void VisitExpr_(const VarNode *op) override { + auto it = data2param.find(op); + if (it != data2param.end()) { + used_params_by_data.insert(it->second); + } + auto it2 = shape2params.find(op); + if (it2 != shape2params.end()) { + used_shape_vars.insert(op); + } + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode *op) override { + auto it = data2param.find(op->buffer->data.get()); + if (it != data2param.end()) { + used_params_by_data.insert(it->second); + } + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode *op) override { + auto it = data2param.find(op->buffer->data.get()); + if (it != data2param.end()) { + used_params_by_data.insert(it->second); + } + StmtExprVisitor::VisitExpr_(op); + } + + const std::unordered_map &data2param; + const std::unordered_map> + &shape2params; + std::unordered_set used_params_by_data; + std::unordered_set used_shape_vars; + }; + + UsedBufferDetector detector(data_var2param, shape_var2params); + detector(func_ptr->body); + + // Build the packed argument handling. While doing so, keep track of whether + // each parameter buffer is actually used. Unused input buffers can be + // nullable and do not require DLTensor field dereferences. + // + // Start from buffers used via data-var (definitely non-NULL), then for each + // referenced shape symbol pick a minimal "carrier" buffer that provides the + // symbol. Prefer carriers that are already used-by-data; otherwise pick one + // arbitrary carrier to ensure the symbol is bound. + std::unordered_set used_param_buffers = + detector.used_params_by_data; + for (const VarNode *sym : detector.used_shape_vars) { + auto it = shape_var2params.find(sym); + if (it == shape_var2params.end()) + continue; + const auto &carriers = it->second; + bool has_used_carrier = false; + for (const VarNode *p : carriers) { + if (used_param_buffers.count(p)) { + has_used_carrier = true; + break; + } + } + // NOTE: With the new nullable shape binding logic in + // ArgBinder::BindDLTensors, we no longer need to force one carrier to be + // non-NULL. The binder will: + // 1. Assert that at least one carrier is non-NULL at runtime + // 2. Use cascaded if_then_else to read from the first non-NULL carrier + // So we can allow all carriers to be nullable. + // if (!has_used_carrier && !carriers.empty()) { + // used_param_buffers.insert(carriers.front()); + // } + } + + for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { + Var param = func_ptr->params[i]; + PrimExpr arg_value; + // type index checks + Var type_index(param->name_hint + ".type_index", DataType::Int(32)); + seq_init.push_back(LetStmt( + type_index, + tir::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}), + nop)); + DataType dtype = param.dtype(); + if (dtype.is_handle()) { + std::ostringstream msg; + // Prefer the Buffer name if available; otherwise, fall back to param name + // (trim _handle). 
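+      // e.g. a handle parameter named "X_handle" with no buffer_map entry is
+      // reported in the error message simply as "X".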
+ std::string display_name; + auto it_buf = func_ptr->buffer_map.find(param); + if (it_buf != func_ptr->buffer_map.end()) { + const auto &kv = *it_buf; + display_name = kv.second->data->name_hint; + } else { + display_name = param->name_hint; + const char *suffix = "_handle"; + if (display_name.size() >= 7 && + display_name.compare(display_name.size() - 7, 7, suffix) == 0) { + display_name.erase(display_name.size() - 7); + } + } + msg << "kernel " << name_hint << " input " << display_name + << " expected pointer or tensor handle"; + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone || + type_index == ffi::TypeIndex::kTVMFFIOpaquePtr || + type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr || + type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin, + tvm::tir::StringImm(msg.str()), nop)); + // if type_index is Tensor, we need to add the offset of the DLTensor + // header which always equals 16 bytes, this ensures that T.handle always + // shows up as a DLTensor* + const int64_t object_cell_offset = sizeof(TVMFFIObject); + static_assert(object_cell_offset == 24); + arg_value = f_load_arg_value(param.dtype(), i); + PrimExpr handle_from_tensor = + Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), + {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); + arg_value = Select(type_index == ffi::TypeIndex::kTVMFFITensor, + handle_from_tensor, arg_value); + } else if (dtype.is_bool()) { + std::ostringstream msg; + msg << "kernel " << name_hint << " scalar " << param->name_hint + << " expected boolean"; + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool || + type_index == ffi::TypeIndex::kTVMFFIInt, + tvm::tir::StringImm(msg.str()), nop)); + arg_value = + Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i)); + + } else if (dtype.is_int() || dtype.is_uint()) { + std::ostringstream msg; + msg << "kernel " << name_hint << " scalar " << param->name_hint + << " expected integer"; + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt || + type_index == ffi::TypeIndex::kTVMFFIBool, + tvm::tir::StringImm(msg.str()), nop)); + arg_value = f_load_arg_value(param.dtype(), i); + } else { + ICHECK(dtype.is_float()); + std::ostringstream msg; + msg << "kernel " << name_hint << " scalar " << param->name_hint + << " expected float"; + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat || + type_index == ffi::TypeIndex::kTVMFFIInt || + type_index == ffi::TypeIndex::kTVMFFIBool, + tvm::tir::StringImm(msg.str()), nop)); + // use select so we can also handle int conversion to bool + arg_value = tir::Select( + type_index == ffi::TypeIndex::kTVMFFIFloat, + /* true_value = */ f_load_arg_value(param.dtype(), i), + /* false_value = */ + Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i))); + } + var_def.emplace_back(arg_value, param); + if (func_ptr->buffer_map.count(param)) { + // buffer binding now depends on type index + // if the index is Tensor handle, we need to offset to get the DLTensor* + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); + } + } + + // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* + // v_result) + ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, + v_result}; + + // Arg definitions are defined before buffer binding to avoid the use before + // def errors. + // + // For example, for auto broadcasting, checks are required to guarantee that + // either 0 or the original stride will be correctly used. 
Checks here have + // to use the args that may have no let binding yet. Therefore, hoisting let + // binding for args before buffer declaration is needed. + for (const auto &[expr, param] : var_def) { + binder.Bind(param, expr, name_hint + "." + param->name_hint, true); + } + + binder.BindDLTensors(buffer_def, device_type, device_id, name_hint, + used_param_buffers); + for (const auto &[var, buffer] : buffer_def) { + // Prefer buffer data var name in diagnostics to avoid exposing low-level + // handle vars + arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); + } + + // reset global symbol to attach prefix + func = WithAttrs( + std::move(func), + {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, + ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); + + Stmt body = ReturnRewriter(v_result)(func_ptr->body); + body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope, + StringImm(name_hint + "_compute_"), body); + // Set device context + if (vmap.count(device_id.get())) { + ffi::Any node = ffi::String("default"); + seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop)); + seq_check.push_back( + AttrStmt(node, tir::attr::device_type, device_type, nop)); + + if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) { + Stmt set_device = + Evaluate(Call(DataType::Int(32), tir::builtin::tvm_call_packed(), + {StringImm(runtime::symbol::tvm_set_device), + device_type, device_id})); + body = SeqStmt({set_device, body}); + } + } + + // Return error code of zero on success + body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + + body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(), + arg_buffer_declarations}, + body); + func_ptr->body = body; + func_ptr->params = args; + + ffi::Array undefined = UndefinedVars(body, func_ptr->params); + + ICHECK_EQ(undefined.size(), 0) + << "In PrimFunc " << name_hint << " variables " << undefined + << " are used, but are not passed in as API arguments"; + + func_ptr->buffer_map = ffi::Map(); + func_ptr->ret_type = PrimType(DataType::Int(32)); + // return the function. 
+ return func; +} + +tvm::transform::Pass MakePackedAPI() { + using tvm::transform::Pass; + auto pass_func = [](IRModule mod, const tvm::transform::PassContext &ctx) { + Map packed_func_methods; + for (const auto &[gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + const auto &prim_func = opt.value(); + if (auto global_symbol = RequiresPackedAPI(prim_func)) { + packed_func_methods.Set(gvar, global_symbol.value()); + } + } + } + + IRModuleNode *mptr = mod.CopyOnWrite(); + IRModule updates; + + for (const auto &[gvar, base_func] : mptr->functions) { + if (auto opt = base_func.as()) { + auto func = opt.value(); + auto orig_func = func; + + if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, + func->body)) { + func.CopyOnWrite()->body = body.value(); + } + func = MakePackedAPI(std::move(func)); + func = MergeIfStmtSubstitute(func); + + if (!func.same_as(orig_func)) { + updates->Add(gvar, func); + } + } + } + + if (!updates->functions.empty()) { + mod.CopyOnWrite()->Update(updates); + } + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.MakePackedAPI", + []() { return MakePackedAPI(); }); +} + +} // namespace tl +} // namespace tvm \ No newline at end of file diff --git a/tilelang/original/src/transform/merge_if_stmt.cc b/tilelang/original/src/transform/merge_if_stmt.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d9d3ac22e0e63c214a0b7fe31a4e4776fa413c --- /dev/null +++ b/tilelang/original/src/transform/merge_if_stmt.cc @@ -0,0 +1,138 @@ +/*! + * \file if_stmt_binding.cc + * \brief Merge the If Stmt in SeqStmt + */ + +#include "merge_if_stmt.h" + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class MergeIfStmtRewriter : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f) { + f.CopyOnWrite()->body = MergeIfStmtRewriter::Apply(f->body); + return f; + } + + static Stmt Apply(Stmt stmt) { + auto rewriter = MergeIfStmtRewriter(); + return rewriter(stmt); + } + +private: + MergeIfStmtRewriter() = default; + + void FlattenAppend(const Stmt &s, Array *out) { + if (const auto *seq = s.as()) { + for (const Stmt &e : seq->seq) { + FlattenAppend(e, out); + } + } else { + out->push_back(s); + } + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + // First, recursively flatten nested SeqStmt so that + // SeqStmt{ if, SeqStmt{ if, SeqStmt{ if } } } + // becomes a single-level sequence of [if, if, if]. + Array flat_seq; + for (const Stmt &stmt : op->seq) { + Stmt new_stmt = this->VisitStmt(stmt); + FlattenAppend(new_stmt, &flat_seq); + } + + // Then, merge consecutive IfThenElse (without else) that share the same + // condition. + Array new_seq; + PrimExpr current_condition; + Array current_if_bodies; + + for (const Stmt &stmt : flat_seq) { + if (const auto *if_node = stmt.as()) { + if (!if_node->else_case.defined()) { + if (current_condition.defined() && + ExprDeepEqual()(current_condition, if_node->condition)) { + current_if_bodies.push_back(if_node->then_case); + continue; + } else { + if (!current_if_bodies.empty()) { + auto if_stmt = + IfThenElse(current_condition, + current_if_bodies.size() == 1 + ? 
current_if_bodies[0] + : this->VisitStmt(SeqStmt(current_if_bodies)), + Stmt()); + new_seq.push_back(if_stmt); + current_if_bodies.clear(); + } + + current_condition = if_node->condition; + current_if_bodies.push_back(if_node->then_case); + continue; + } + } + } + + if (!current_if_bodies.empty()) { + auto if_stmt = + IfThenElse(current_condition, + current_if_bodies.size() == 1 + ? current_if_bodies[0] + : this->VisitStmt(SeqStmt(current_if_bodies)), + Stmt()); + new_seq.push_back(if_stmt); + current_condition = PrimExpr(); + current_if_bodies.clear(); + } + + new_seq.push_back(stmt); + } + + if (!current_if_bodies.empty()) { + auto if_stmt = + IfThenElse(current_condition, + current_if_bodies.size() == 1 + ? current_if_bodies[0] + : this->VisitStmt(SeqStmt(current_if_bodies)), + Stmt()); + new_seq.push_back(if_stmt); + } + + return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq); + } +}; + +PrimFunc MergeIfStmtSubstitute(PrimFunc &f) { + return MergeIfStmtRewriter::Substitute(f); +} + +Stmt ApplyMergeIfStmt(Stmt stmt) { return MergeIfStmtRewriter::Apply(stmt); } + +using namespace tir::transform; +tvm::transform::Pass MergeIfStmt() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return MergeIfStmtRewriter::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/merge_if_stmt.h b/tilelang/original/src/transform/merge_if_stmt.h new file mode 100644 index 0000000000000000000000000000000000000000..5d7a282d1f7b33c18579da72c2a52cef9574a59d --- /dev/null +++ b/tilelang/original/src/transform/merge_if_stmt.h @@ -0,0 +1,52 @@ +/*! + * \file merge_if_stmt.h + * \brief Merge consecutive If statements with the same condition + */ +#ifndef TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ +#define TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ + +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Forward declaration +class MergeIfStmtRewriter; + +/*! + * \brief Apply MergeIfStmt transformation to a PrimFunc + * + * This function merges consecutive IfThenElse statements that have the same + * condition into a single if statement with a SeqStmt body. + * + * Example: + * if (cond) { stmt1 } + * if (cond) { stmt2 } + * if (cond) { stmt3 } + * + * Becomes: + * if (cond) { + * stmt1 + * stmt2 + * stmt3 + * } + * + * \param f The PrimFunc to transform + * \return Transformed PrimFunc with merged if statements + */ +PrimFunc MergeIfStmtSubstitute(PrimFunc &f); + +/*! + * \brief Apply MergeIfStmt transformation to a statement + * \param stmt The statement to transform + * \return Transformed statement with merged if statements + */ +Stmt ApplyMergeIfStmt(Stmt stmt); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ diff --git a/tilelang/original/src/transform/merge_shared_memory_allocations.cc b/tilelang/original/src/transform/merge_shared_memory_allocations.cc new file mode 100644 index 0000000000000000000000000000000000000000..55f265083dbfce1c66e1219d45fcba6f7035aeb4 --- /dev/null +++ b/tilelang/original/src/transform/merge_shared_memory_allocations.cc @@ -0,0 +1,1372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file merge_shared_memory_allocations.cc + * \brief Each GPU kernel is allowed to have only one dynamic or static shared + * memory allocation. This pass merges multiple TIR-level dynamic or static + * shared memory allocations into one allocation. + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../target/utils.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" +#include "tvm/tir/function.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +using runtime::StorageRank; +using runtime::StorageScope; + +static bool IsDynamicSharedMemory(Var buffer_var) { + StorageScope storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn"; +} + +static bool IsStaticSharedMemory(Var buffer_var) { + StorageScope storage_scope = + runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag.empty(); +} + +/*! + * \brief collect the mapping from the buffer var to its allocate + */ +class AllocateCollector : public StmtExprVisitor { +public: + void VisitStmt_(const AllocateNode *op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_[op->buffer_var.get()] = op; + } else if (IsStaticSharedMemory(op->buffer_var)) { + static_shmem_allocs_[op->buffer_var.get()] = op; + } + StmtExprVisitor::VisitStmt_(op); + } + // The dynamic mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The static mapping from the original buffer var to its allocate + std::unordered_map + static_shmem_allocs_; +}; + +// Find a linear pattern of storage access +// Used for liveness analysis. +// "linear" means fitting a complex access pattern into an array of StmtEntry +// +// Define "scope" as the body of For/thread_launch/IfThenElse +// Composite scopes(loop/thread_launch/IfThen) is represented by three +// StmtEntry: before_scope -> scope_body -> after_scope +// +// This pass tries to detect last point that we need to keep memory +// alive under the same scope as Allocate. +// The storage need to be kept alive between Allocate and last access. +// The free point is only inserted at the same scope of Allocate. +// +class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { +public: + explicit SharedMemLinearAccessPatternFinder( + bool is_dynamic = true, bool enable_aggressive_merge = false, + bool verbose = false) + : is_dynamic_(is_dynamic), + enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {} + /*! 
\brief record the touch list of statement. */ + struct StmtEntry { + // The statement + const Object *stmt{}; + // The index in the linear_seq_ to point to end of the nested scope. + // This is only set to non-zero if stmt is a nested scope. + // if offset > 0, means this is the begin, the end entry is current_index + + // offset if offset < 0, means this is the end, the begin entry is + // current_index + offset + int64_t scope_pair_offset{0}; + // The buffer variables this statement touched. + std::vector touched; + }; + // The scope of each allocation + struct AllocEntry { + // the level in the scope stack + size_t level{0}; + // allocation stmt + const AllocateNode *alloc{nullptr}; + }; + + struct StmtAttr { + // the level in the scope stack + size_t level{0}; + }; + + void UpdateStmtAttr(const Object *stmt, size_t level) { + if (stmt_attrs_.find(stmt) == stmt_attrs_.end()) { + stmt_attrs_[stmt] = StmtAttr{level}; + } else { + stmt_attrs_[stmt].level = level; + } + } + + void VisitStmt_(const AllocateNode *op) final { + size_t level = scope_.size(); + const VarNode *buf = op->buffer_var.get(); + // Record the allocation site and depth so liveness can reason about the + // original scope. + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + // Add write access. + const VarNode *buf = op->buffer->data.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { + // set into scope_.size() - 1 for aggressive memory reuse + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + scope_[it->second.level].touched.push_back(buf); + } + } + } + + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (!e.touched.empty()) { + e.stmt = op; + UpdateStmtAttr(op, scope_level_); + linear_seq_.push_back(e); + } + } + + void VisitStmt_(const EvaluateNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (!e.touched.empty()) { + e.stmt = op; + UpdateStmtAttr(op, scope_level_); + linear_seq_.push_back(e); + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + const VarNode *buf = op->buffer->data.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + // Earlier we required `alloc_level < scope_.size()`, assuming every load + // would occur strictly inside a nested scope. In practice the lowering + // pipeline may materialise reads in the very same frame that owns the + // allocation (e.g. when the buffer value is passed directly to a call), + // which used to trigger the CHECK. Treat same-level accesses as valid so + // the merged allocator can reason about their lifetime correctly. 
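    // Illustrative sketch (not normative, added for clarity): with
    //
    //   allocate A                         // alloc_info_[A].level records the depth here
    //   ... A passed straight to a call ...  // read materialised in the owning frame
    //   for i { ... A[i] ... }               // read inside a nested frame
    //
    // both reads satisfy level <= scope_.size(); only a recorded level deeper
    // than the current scope stack would indicate malformed nesting, which is
    // what the relaxed check below still rejects.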
+ ICHECK_LE(it->second.level, scope_.size()) + << "Load memory in places other than store."; + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + // When the access happens in the same scope frame as the allocation + // we attribute it to that frame instead of the outer parent. This + // keeps the liveness window tight while still accounting for nested + // scopes that legitimately touch the buffer deeper in the tree. + size_t access_level = std::min(it->second.level, scope_.size() - 1); + scope_[access_level].touched.push_back(buf); + } + } + } + } + + void VisitExpr_(const VarNode *buf) final { + // Directly reference to the variable count as a read. + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + // Same rationale as the BufferLoad path above: direct references can be + // emitted at the allocation level after flattening, so accept them and + // record the touch for liveness planning. + ICHECK_LE(it->second.level, scope_.size()); + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + // Attribute same-level uses to the allocation frame, mirroring the + // BufferLoad handling to keep reuse decisions consistent. + size_t access_level = std::min(it->second.level, scope_.size() - 1); + scope_[access_level].touched.push_back(buf); + } + } + } + } + + template void VisitNewScope(const T *op) { + scope_.push_back(StmtEntry()); + StmtEntry e; + e.stmt = op; + UpdateStmtAttr(op, scope_level_); + int64_t begin_index = static_cast(linear_seq_.size()); + // before scope. + linear_seq_.push_back(e); + StmtExprVisitor::VisitStmt_(op); + // after scope. + e.touched = std::move(scope_.back().touched); + scope_.pop_back(); + int64_t end_index = static_cast(linear_seq_.size()); + ICHECK_GT(end_index, begin_index); + // The paired entries serve as scope sentinels once we flatten the + // control-flow tree. + e.scope_pair_offset = begin_index - end_index; + linear_seq_.push_back(e); + // record the pointer to end index. + ICHECK_NE(end_index, 0U); + linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; + } + + void VisitStmt_(const AttrStmtNode *op) final { + // Only record the outer most thread extent. 
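    // Illustrative sketch (not normative): for nested launches such as
    //
    //   attr [blockIdx.x]  thread_extent = 128 {   // outermost -> VisitNewScope
    //     attr [threadIdx.x] thread_extent = 256 { // nested -> default visitor
    //       ...
    //     }
    //   }
    //
    // only the outermost attribute opens a sentinel pair in linear_seq_;
    // in_thread_env_ keeps the inner attribute from creating a second scope.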
+ if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + VisitNewScope(op); + in_thread_env_ = false; + } else if (op->attr_key == tir::attr::extern_scope) { + VisitNewScope(op); + } else if (op->attr_key == tir::attr::virtual_thread) { + VisitNewScope(op); + } else if (op->attr_key == "kWarpSpecializationScope") { + IfThenElse body = Downcast(op->body); + this->VisitStmt(body->then_case); + this->VisitStmt(body->else_case.value()); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + + void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); } + + bool ContainsSeqStmt(const Stmt &stmt) { + if (stmt->IsInstance()) { + return true; + } + if (const auto *if_node = stmt.as()) { + return ContainsSeqStmt(if_node->then_case) || + (if_node->else_case.defined() && + ContainsSeqStmt(if_node->else_case.value())); + } + return false; + } + + void VisitStmt_(const ForNode *op) final { + if (ContainsSeqStmt(op->body)) { + scope_level_++; + VisitNewScope(op); + scope_level_--; + } else { + VisitNewScope(op); + } + } + + void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const AssertStmtNode *op) final { VisitNewScope(op); } + + // linearized access sequence. + std::vector linear_seq_; + // The storage scope of each buffer + std::unordered_map alloc_info_; + // The attribute of each statement + std::unordered_map stmt_attrs_; + +private: + // Wrapper function to determine if the shared memory allocation for a + // variable is appropriate. + bool IsAppropriateSharedMemory(const Var &var) { + return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var); + } + // Whether do dynamic analysis. + bool is_dynamic_{true}; + // Whether do aggressive merge. + bool enable_aggressive_merge_{false}; + // Whether do verbose logging. + bool verbose_{false}; + // Whether already in thread env. + bool in_thread_env_{false}; + // The scope stack. + std::vector scope_; + // The size of the scope. + size_t scope_level_{0}; +}; + +class SharedMemoryAlignmentPlanner : public StmtExprVisitor { + +public: + static std::unordered_map Plan(const Stmt &stmt) { + SharedMemoryAlignmentPlanner planner; + planner(stmt); + return planner.shmem_alignment_map_; + } + +private: + // Helper to record alignment for a shared/shared.dyn Var under alignment + // scope + void MarkSharedVarIfNeeded(const VarNode *op) { + if (!op || !under_alignment_scope_) + return; + auto ptr_type = op->type_annotation.as(); + if (!ptr_type) + return; + auto scope = GetPtrStorageScope(tvm::ffi::GetRef(op)); + if (scope == "shared" || scope == "shared.dyn") { + auto target = Target::Current(); + ICHECK(target.defined()) << "Target is not defined"; + const int alignment = TargetIsHopper(target) ? 1024 : 16; + shmem_alignment_map_[op] = alignment; + } + } + + void VisitExpr_(const CallNode *op) { + if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) || + op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) || + op->op.same_as(tl::initialize_wgmma_descriptor()) || + op->op.same_as(tl::initialize_tcgen05_descriptor())) { + // These intrinsics introduce stricter SMEM alignment requirements; mark + // the subtree. 
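    // Illustrative sketch (not normative): any shared/shared.dyn variable that
    // is reachable while this flag is set gets the target-specific alignment
    // computed in MarkSharedVarIfNeeded (1024 bytes on Hopper, 16 otherwise), e.g.
    //
    //   call(tl.tma_load, ..., A_shared, ...)  // A_shared recorded with the stricter alignment
    //   plain BufferStore to A_shared          // no extra alignment constraint recorded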
+ under_alignment_scope_ = true; + StmtExprVisitor::VisitExpr_(op); + under_alignment_scope_ = false; + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + + void VisitExpr_(const VarNode *op) { + MarkSharedVarIfNeeded(op); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const BufferLoadNode *op) { + // If we encounter address_of(BufferLoad(...)) or any direct BufferLoad + // within an alignment scope, make sure we mark the underlying shared var. + if (op && under_alignment_scope_) { + const VarNode *data_var = op->buffer->data.get(); + MarkSharedVarIfNeeded(data_var); + } + StmtExprVisitor::VisitExpr_(op); + } + + bool under_alignment_scope_{false}; + + std::unordered_map shmem_alignment_map_; +}; + +/*! + * \brief merge the buffers whose live range has no intersection and rewrite the + * body + */ +class SharedMemoryRewriter : public StmtExprMutator { +public: + explicit SharedMemoryRewriter( + const std::unordered_map + &shmem_allocs, + bool is_dynamic = true, bool verbose = false, int align_bytes = 0) + : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs}, verbose_{verbose}, + align_bytes_{align_bytes} { + if (!is_dynamic) { + merged_buf_var_ = + Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared")); + } + } + + /*! + * \brief plan the memory reuse for all the buffer allocated in the statement + * \param stmt the statement + */ + void PlanReuse(const Stmt &stmt, bool is_dynamic = true, + bool enable_aggressive_merge = false, bool verbose = false) { + SharedMemLinearAccessPatternFinder finder(is_dynamic, + enable_aggressive_merge, verbose); + finder(stmt); + shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt); + // First compute liveness over the flattened schedule, then feed it into the + // arena packer. + this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_); + this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_); + } + +private: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent && !allocated_) { + // Allocate one dynamic shared memory allocation at the beginning of + // thread scope + + if (verbose_) { + + LOG(DEBUG) << "Memory Allocation Plan for " + << (is_dynamic_ ? 
"Dynamic" : "Static") << " Shared Memory:"; + LOG(DEBUG) << " Merged Buffer Name: " << merged_buf_var_->name_hint; + LOG(DEBUG) << " Total Merged Size: " << merged_alloc_size_ << " bytes"; + LOG(DEBUG) << " Individual Buffer Allocations:"; + for (const auto &pair : buffer_byte_offsets_) { + const VarNode *buffer_var_node = pair.first; + PrimExpr byte_offset = pair.second; + auto alloc_it = shmem_allocs_.find(buffer_var_node); + if (alloc_it != shmem_allocs_.end()) { + const AllocateNode *alloc = alloc_it->second; + PrimExpr buffer_size_bytes = + alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes(); + LOG(DEBUG) << " Buffer: " << buffer_var_node->name_hint + << " (Type: " << alloc->dtype << ")" + << ", Start Offset: " << byte_offset + << ", Size: " << buffer_size_bytes << " bytes" + << ", End Offset: " + << (byte_offset + buffer_size_bytes - 1); + } else { + LOG(DEBUG) << " Buffer: " << buffer_var_node->name_hint + << ", Start Offset: " << byte_offset + << " (Original allocation info not found)"; + } + } + LOG(DEBUG) << "End of Memory Allocation Plan."; + } + + allocated_ = true; + Allocate new_body(merged_buf_var_, DataType::UInt(8), + {merged_alloc_size_}, const_true(), + StmtExprMutator::VisitStmt(op->body)); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateNode *op) final { + if (IsAppropriateSharedMemory(op->buffer_var)) { + return StmtExprMutator::VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + auto new_buf = GetUpdatedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buf; + } + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template Node VisitBufferAccess(Node node) { + if (IsAppropriateSharedMemory(node->buffer->data)) { + ICHECK_EQ(node->indices.size(), 1) + << "MergeSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + Array indices = { + node->indices[0] + + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetUpdatedBuffer(node->buffer); + writer->indices = indices; + } + + return node; + } + + Buffer GetUpdatedBuffer(Buffer buffer) { + auto key = buffer.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + + if (IsAppropriateSharedMemory(buffer->data)) { + ICHECK_EQ(buffer->shape.size(), 1) + << "Buffer " << buffer << " has shape " << buffer->shape << ". 
" + << "MergeSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + auto writer = buffer.CopyOnWrite(); + writer->data = merged_buf_var_; + } + + buffer_remap_[key] = buffer; + return buffer; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + Var buffer = Downcast(op->args[1]); + if (!IsAppropriateSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + return Call(op->dtype, op->op, + {op->args[0], merged_buf_var_, extra_offset + offset, extent, + op->args[4]}); + } else if (op->op.same_as(builtin::ptx_cp_async())) { + ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); + DataType dtype = op->dtype; + Var buffer = Downcast(op->args[0]); + if (!IsAppropriateSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + PrimExpr offset = this->VisitExpr(op->args[1]); + // the dst shared memory is a byte buffer generated by merging shared + // memory. we need to multiply the offset index by the byte size of the + // original value dtype, to get the correct offset of merged shared + // buffer. + int index_factor = dtype.bytes(); + if (op->args.size() == 5) + return Call(dtype, op->op, + {merged_buf_var_, + mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4]}); + else + return Call(dtype, op->op, + {merged_buf_var_, + mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4], op->args[5]}); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + + PrimExpr GetBufferOffset(const Var &buffer_var, DataType dtype) { + auto it = buffer_byte_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_byte_offsets_.end()) + << "buffer_var = " << buffer_var->name_hint << ", dtype = " << dtype; + return indexdiv(it->second, dtype.bytes() * dtype.lanes()); + } + + // Wrapper function to determine if the shared memory allocation for a + // variable is appropriate. + bool IsAppropriateSharedMemory(const Var &var) { + return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var); + } + + using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry; + using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr; + + // Metadata about a single shared-memory allocation prior to merging. This + // is used to build lifetimes, alignment requirements, and final offsets. + struct BufInfo { + const VarNode *var{nullptr}; + std::string name; + PrimExpr size_expr; + std::optional const_size_bytes; // in bytes if compile-time known. + int alignment{0}; // required byte alignment. + int start{0}; // first statement index touching the buf. + int end{0}; // one-past-last statement index. + DataType size_dtype{DataType::Int(32)}; + }; + + // Interval describing the liveness window of a (constant-sized) allocation. + struct Interval { + int start{0}; + int end{0}; + size_t size_bytes{0}; + int alignment{0}; + const VarNode *var{nullptr}; + }; + + // Result of a linear-scan arena packing. Offsets contain the byte offset for + // each constant-sized buffer, arena_size is the total constant footprint. 
+ struct ArenaPlan { + size_t arena_size{0}; + std::unordered_map offsets; + }; + + static size_t AlignUpSize(size_t value, size_t alignment) { + if (alignment == 0) { + return value; + } + size_t remainder = value % alignment; + if (remainder == 0) { + return value; + } + return value + (alignment - remainder); + } + + struct FreeBlock { + size_t offset{0}; + size_t size{0}; + }; + + class FreeList { + public: + std::optional Allocate(size_t need, size_t alignment) { + // Best-fit search: pick the slot that wastes the least space after + // alignment. + int best = -1; + size_t best_waste = std::numeric_limits::max(); + for (int i = 0, n = static_cast(blocks_.size()); i < n; ++i) { + size_t aligned = AlignUpSize(blocks_[i].offset, alignment); + size_t head = aligned - blocks_[i].offset; + if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) { + size_t waste = blocks_[i].size - head - need; + if (waste < best_waste) { + best_waste = waste; + best = i; + } + } + } + if (best < 0) { + return std::nullopt; + } + FreeBlock blk = blocks_[best]; + size_t aligned = AlignUpSize(blk.offset, alignment); + size_t head = aligned - blk.offset; + size_t tail = blk.size - head - need; + blocks_.erase(blocks_.begin() + best); + if (head) { + blocks_.push_back({blk.offset, head}); + } + if (tail) { + blocks_.push_back({aligned + need, tail}); + } + Normalize(); + return aligned; + } + + void Free(size_t offset, size_t size) { + if (size == 0) + return; + blocks_.push_back({offset, size}); + Normalize(); + } + + private: + void Normalize() { + if (blocks_.empty()) + return; + std::sort(blocks_.begin(), blocks_.end(), + [](const FreeBlock &a, const FreeBlock &b) { + return a.offset < b.offset; + }); + std::vector merged; + merged.reserve(blocks_.size()); + for (const FreeBlock &blk : blocks_) { + if (merged.empty()) { + merged.push_back(blk); + continue; + } + FreeBlock &last = merged.back(); + size_t last_end = last.offset + last.size; + if (blk.offset <= last_end) { + size_t blk_end = blk.offset + blk.size; + if (blk_end > last_end) { + last.size = blk_end - last.offset; + } + } else { + merged.push_back(blk); + } + } + blocks_ = std::move(merged); + } + + std::vector blocks_; + }; + + struct ActiveInterval { + int end{0}; + size_t offset{0}; + size_t size{0}; + const VarNode *var{nullptr}; + bool operator>(const ActiveInterval &other) const { + return end > other.end; + } + }; + + static ArenaPlan LinearScanPack(std::vector intervals) { + // Process intervals in program order so lifetimes correspond to the + // linearised CFG. + std::sort(intervals.begin(), intervals.end(), + [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + if (lhs.size_bytes != rhs.size_bytes) { + return lhs.size_bytes > rhs.size_bytes; + } + return lhs.var < rhs.var; + }); + + std::priority_queue, + std::greater> + active; + FreeList freelist; + size_t arena_top = 0; + std::unordered_map offsets; + + // Expire intervals that end before or at program counter `pc`. + auto retire = [&](int pc) { + while (!active.empty() && active.top().end <= pc) { + const ActiveInterval top = active.top(); + active.pop(); + freelist.Free(top.offset, top.size); + } + }; + + for (const Interval &interval : intervals) { + retire(interval.start); + size_t offset = 0; + // Try to recycle previously freed memory first; fall back to bumping the + // arena. 
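      // Illustrative trace (not normative): if the free list currently holds
      // [0, 1024) and [2048, 8192) and this interval needs 4096 bytes with
      // 1024-byte alignment, best-fit picks the second block and returns
      // offset 2048; if neither block fits, the arena top is aligned up and
      // extended instead.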
+ if (auto slot = + freelist.Allocate(interval.size_bytes, interval.alignment)) { + offset = slot.value(); + } else { + offset = AlignUpSize(arena_top, interval.alignment); + arena_top = offset + interval.size_bytes; + } + active.push(ActiveInterval{interval.end, offset, interval.size_bytes, + interval.var}); + offsets[interval.var] = offset; + } + + return ArenaPlan{arena_top, std::move(offsets)}; + } + + PrimExpr AlignPrimExpr(const PrimExpr &value, int alignment) const { + if (alignment <= 1) { + return value; + } + DataType dtype = value.dtype(); + ICHECK(dtype.is_int() || dtype.is_uint()) + << "Expected integer dtype for alignment, but got " << dtype; + PrimExpr align_expr = make_const(dtype, alignment); + PrimExpr adjust = make_const(dtype, alignment - 1); + return indexdiv(value + adjust, align_expr) * align_expr; + } + + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + void PlanAlignment(const Stmt &stmt) { + DLOG(INFO) << "PlanAlignment"; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (const auto *call = node.as()) { + if (call->op.same_as(tl::tl_gemm()) || + call->op.same_as(tl::tl_gemm_sp())) { + DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: " + << call->op; + } + } + }); + } + /*! + * \brief Liveness analysis to find gen and kill point of each variable. + * \param seq the linear pattern of storage access + */ + void LivenessAnalysis( + const std::vector &seq, + const std::unordered_map &stmt_attrs) { + // find kill point, do a reverse linear scan. + std::unordered_set touched; + for (size_t i = seq.size(); i != 0; --i) { + const StmtEntry &s = seq[i - 1]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].kill.push_back(buffer); + } + } + } + // find gen point, do forward scan + touched.clear(); + for (size_t i = 0; i < seq.size(); ++i) { + int64_t offset = seq[i].scope_pair_offset; + if (offset < 0) + continue; + const StmtEntry &s = seq[i + offset]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].gen.push_back(buffer); + } + } + } + + if (verbose_) { + std::vector stmt_keys; + for (const auto &stmt_entry : seq) { + auto stmt = stmt_entry.stmt; + if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) == + stmt_keys.end()) { + stmt_keys.push_back(stmt); + } + } + LOG(DEBUG) << "Before reorder kill points, Liveness Analysis Results for " + << (is_dynamic_ ? 
"Dynamic" : "Static") << " Shared Memory:"; + for (const auto &stmt_key : stmt_keys) { + auto it = event_map_.find(stmt_key); + if (it == event_map_.end()) + continue; + + const EventEntry &entry = it->second; + if (entry.gen.empty() && entry.kill.empty()) + continue; + ICHECK(stmt_attrs.count(stmt_key)) + << "stmt_key = " << stmt_key->GetTypeKey(); + auto level = stmt_attrs.at(stmt_key).level; + LOG(DEBUG) << " Statement: " << stmt_key->GetTypeKey() + << " (scope_level: " << level << ")"; + + std::stringstream gen_vars_ss; + bool x_generated = false; + for (const VarNode *var : entry.gen) { + gen_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_generated = true; + } + } + if (!entry.gen.empty()) { + std::string gen_log_msg = " GEN: " + gen_vars_ss.str(); + if (x_generated) { + gen_log_msg += " <-- Buffer 'x' generated"; + } + LOG(DEBUG) << gen_log_msg; + } + + std::stringstream kill_vars_ss; + bool x_killed = false; + for (const VarNode *var : entry.kill) { + kill_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_killed = true; + } + } + if (!entry.kill.empty()) { + std::string kill_log_msg = " KILL: " + kill_vars_ss.str(); + if (x_killed) { + kill_log_msg += " <-- Buffer 'x' killed"; + } + LOG(DEBUG) << kill_log_msg; + } + } + LOG(DEBUG) << "End of Liveness Analysis Results."; + } + + // Reorder kill points: + // For each buffer, if its kill statement is at a deeper scope level than + // its gen statement, we need to move the kill point to the end of the gen + // statement's scope level. This ensures proper memory deallocation at the + // right scope boundary. + std::vector gen_kill_seq; + for (const auto &stmt_entry : seq) { + // if has gen and kill, add to gen_kill_seq + if (!event_map_[stmt_entry.stmt].gen.empty() || + !event_map_[stmt_entry.stmt].kill.empty()) { + gen_kill_seq.push_back(stmt_entry); + } + } + + for (auto &event_pair : event_map_) { + const Object *stmt = event_pair.first; + EventEntry &event = event_pair.second; + + // Skip if no kill points to process + if (event.kill.empty()) + continue; + + // Get scope level of current statement + ICHECK(stmt_attrs.count(stmt)); + int kill_level = stmt_attrs.at(stmt).level; + + std::unordered_set visited_buffers; + + // For each killed buffer, find its gen statement and check scope levels + for (auto it = event.kill.begin(); it != event.kill.end();) { + const VarNode *buffer = *it; + bool found_gen = false; + int gen_level = 0; + + // Find the gen statement for this buffer + for (const auto &gen_pair : event_map_) { + const auto &gen_event = gen_pair.second; + if (std::find(gen_event.gen.begin(), gen_event.gen.end(), buffer) != + gen_event.gen.end()) { + found_gen = true; + gen_level = stmt_attrs.at(gen_pair.first).level; + break; + } + } + + if (found_gen && kill_level > gen_level) { + if (visited_buffers.count(buffer)) { + ++it; + continue; + } + // Need to move kill point - remove from current event + it = event.kill.erase(it); + + // Find the last statement at gen_level and add kill point there + // Find the last statement at gen_level in the sequence + const Object *last_stmt_at_level = nullptr; + auto stmt_it = gen_kill_seq.begin(); + for (; stmt_it != gen_kill_seq.end(); ++stmt_it) { + if (stmt_it->stmt == stmt) { + break; + } + } + // start from current statement and find the last statement at + // gen_level + + for (; stmt_it != gen_kill_seq.end(); ++stmt_it) { + // Check if next statement has different level + auto next_it = stmt_it + 1; + if (next_it == gen_kill_seq.end() || + 
stmt_attrs.at(next_it->stmt).level == gen_level) { + last_stmt_at_level = stmt_it->stmt; + break; + } + } + if (last_stmt_at_level) { + event_map_[last_stmt_at_level].kill.push_back(buffer); + visited_buffers.insert(buffer); + } + } else { + ++it; + } + } + } + + std::vector stmt_keys; + for (const auto &stmt_entry : seq) { + auto stmt = stmt_entry.stmt; + if (std::find(stmt_keys.begin(), stmt_keys.end(), stmt) == + stmt_keys.end()) { + stmt_keys.push_back(stmt); + } + } + + if (verbose_) { + LOG(DEBUG) << "Liveness Analysis Results for " + << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:"; + for (const auto &stmt_key : stmt_keys) { + auto it = event_map_.find(stmt_key); + if (it == event_map_.end()) + continue; + + const EventEntry &entry = it->second; + if (entry.gen.empty() && entry.kill.empty()) + continue; + ICHECK(stmt_attrs.count(stmt_key)) + << "stmt_key = " << stmt_key->GetTypeKey(); + auto level = stmt_attrs.at(stmt_key).level; + LOG(DEBUG) << " Statement: " << stmt_key->GetTypeKey() + << " (scope_level: " << level << ")"; + + std::stringstream gen_vars_ss; + bool x_generated = false; + for (const VarNode *var : entry.gen) { + gen_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_generated = true; + } + } + if (!entry.gen.empty()) { + std::string gen_log_msg = " GEN: " + gen_vars_ss.str(); + if (x_generated) { + gen_log_msg += " <-- Buffer 'x' generated"; + } + LOG(DEBUG) << gen_log_msg; + } + + std::stringstream kill_vars_ss; + bool x_killed = false; + for (const VarNode *var : entry.kill) { + kill_vars_ss << var->name_hint << " "; + if (var->name_hint == "x") { + x_killed = true; + } + } + if (!entry.kill.empty()) { + std::string kill_log_msg = " KILL: " + kill_vars_ss.str(); + if (x_killed) { + kill_log_msg += " <-- Buffer 'x' killed"; + } + LOG(DEBUG) << kill_log_msg; + } + } + LOG(DEBUG) << "End of Liveness Analysis Results."; + } + } + + /*! + * \brief Memory plan algorithm + * \param seq the linear pattern of storage access + * \param alloc_info + */ + void + PlanMemory(const std::vector &seq, + const std::unordered_map &stmt_attrs) { + buffer_byte_offsets_.clear(); + (void)stmt_attrs; + + if (shmem_allocs_.empty()) { + merged_alloc_size_ = make_const(DataType::Int(64), 0); + return; + } + + // Discover the first and last touch for every allocation. + std::unordered_map start_index; + std::unordered_map end_index; + + for (size_t i = 0; i < seq.size(); ++i) { + auto it = event_map_.find(seq[i].stmt); + if (it == event_map_.end()) + continue; + for (const VarNode *var : it->second.gen) { + start_index.emplace(var, static_cast(i)); + } + for (const VarNode *var : it->second.kill) { + end_index[var] = std::max(end_index[var], static_cast(i) + 1); + } + } + + const int seq_len = static_cast(seq.size()); + for (const auto &kv : start_index) { + if (!end_index.count(kv.first)) { + end_index[kv.first] = seq_len; + } + } + + std::vector buf_infos; + buf_infos.reserve(shmem_allocs_.size()); + // Build a BufInfo for all allocations that participate in liveness. 
+ for (const auto &kv : shmem_allocs_) { + const VarNode *var = kv.first; + auto start_it = start_index.find(var); + if (start_it == start_index.end()) { + continue; + } + + BufInfo info; + info.var = var; + info.name = var->name_hint; + info.start = start_it->second; + info.end = std::max(end_index[var], info.start + 1); + info.alignment = align_bytes_; + auto align_it = shmem_alignment_map_.find(var); + if (align_it != shmem_alignment_map_.end()) { + info.alignment = std::max(info.alignment, align_it->second); + } + + const AllocateNode *alloc = kv.second; + int64_t bytes_per_elem = + static_cast(alloc->dtype.bytes() * alloc->dtype.lanes()); + DataType size_dtype = DataType::Int(32); + if (!alloc->extents.empty()) { + size_dtype = alloc->extents[0].dtype(); + } + if (!size_dtype.is_int() && !size_dtype.is_uint()) { + size_dtype = DataType::Int(32); + } + + PrimExpr size_expr = make_const(size_dtype, bytes_per_elem); + for (const PrimExpr &extent : alloc->extents) { + PrimExpr e = extent; + if (e.dtype() != size_dtype) { + e = cast(size_dtype, e); + } + size_expr = size_expr * e; + } + info.size_dtype = size_dtype; + info.size_expr = size_expr; + + int64_t const_extent = alloc->ConstantAllocationSize(); + if (const_extent >= 0) { + info.const_size_bytes = const_extent * bytes_per_elem; + } + + buf_infos.push_back(std::move(info)); + } + + // Stable order so the later passes have deterministic behaviour. + std::sort(buf_infos.begin(), buf_infos.end(), + [](const BufInfo &a, const BufInfo &b) { + if (a.start != b.start) + return a.start < b.start; + if (a.end != b.end) + return a.end < b.end; + return a.name < b.name; + }); + + std::vector intervals; + intervals.reserve(buf_infos.size()); + for (const BufInfo &info : buf_infos) { + if (!info.const_size_bytes.has_value()) + continue; + // Only constant-sized buffers participate in the arena packing because + // dynamic sizes must be placed sequentially later. + Interval interval; + interval.start = info.start; + interval.end = info.end; + interval.size_bytes = static_cast( + std::max(0, info.const_size_bytes.value())); + interval.alignment = info.alignment; + interval.var = info.var; + intervals.push_back(interval); + } + + ArenaPlan plan = LinearScanPack(std::move(intervals)); + size_t arena_size_const = plan.arena_size; + + if (verbose_) { + LOG(DEBUG) << "ArenaPlan (constant buffers): arena_size=" + << arena_size_const; + for (const auto &kv : plan.offsets) { + const VarNode *var = kv.first; + LOG(DEBUG) << " " << var->name_hint << " -> offset=" << kv.second; + } + } + + // Cursor tracks the running byte offset within the merged arena. + DataType offset_dtype = + buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype; + PrimExpr total_size = make_const(offset_dtype, 0); + PrimExpr cursor = AlignPrimExpr( + make_const(offset_dtype, static_cast(arena_size_const)), + align_bytes_); + + auto CastToOffset = [&](PrimExpr expr) -> PrimExpr { + if (expr.dtype() == offset_dtype) { + return expr; + } + return cast(offset_dtype, expr); + }; + + for (const BufInfo &info : buf_infos) { + PrimExpr offset_expr; + auto it = plan.offsets.find(info.var); + if (it != plan.offsets.end()) { + offset_expr = + make_const(offset_dtype, static_cast(it->second)); + } else { + // Dynamic-sized buffers are appended after the constant arena. 
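        // Illustrative example (not normative): with arena_size_const = 6144
        // and a 16-byte default alignment (1024 for a TMA-touched buffer), a
        // symbolic buffer of n * 2 bytes lands at AlignUp(cursor, alignment)
        // and the cursor then advances past its size for the next dynamic
        // buffer.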
+ cursor = AlignPrimExpr(cursor, info.alignment); + PrimExpr size_expr = CastToOffset(info.size_expr); + offset_expr = cursor; + cursor = offset_expr + size_expr; + } + + buffer_byte_offsets_[info.var] = offset_expr; + PrimExpr buf_end = offset_expr + CastToOffset(info.size_expr); + total_size = max(total_size, buf_end); + } + + merged_alloc_size_ = buf_infos.empty() + ? make_const(offset_dtype, 0) + : AlignPrimExpr(total_size, align_bytes_); + + bool overlap_detected = false; + + if (verbose_) { + LOG(DEBUG) << "Memory Allocation Plan for " + << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:"; + LOG(DEBUG) << " Total Merged Size (aligned): " << merged_alloc_size_; + for (const BufInfo &info : buf_infos) { + const PrimExpr &offset = buffer_byte_offsets_.at(info.var); + LOG(DEBUG) << " Buffer: " << info.name << " start=" << info.start + << " end=" << info.end << " alignment=" << info.alignment + << " offset=" << offset << " size=" << info.size_expr; + } + // Sanity check for overlapping constant buffers. + for (size_t i = 0; i < buf_infos.size(); ++i) { + const BufInfo &a = buf_infos[i]; + auto a_off_imm = buffer_byte_offsets_.at(a.var).as(); + if (!a.const_size_bytes.has_value() || a_off_imm == nullptr) + continue; + int64_t a_off = a_off_imm->value; + int64_t a_end = a_off + a.const_size_bytes.value(); + for (size_t j = i + 1; j < buf_infos.size(); ++j) { + const BufInfo &b = buf_infos[j]; + auto b_off_imm = buffer_byte_offsets_.at(b.var).as(); + if (!b.const_size_bytes.has_value() || b_off_imm == nullptr) + continue; + bool live_overlap = !(a.end <= b.start || b.end <= a.start); + if (!live_overlap) + continue; + int64_t b_off = b_off_imm->value; + int64_t b_end = b_off + b.const_size_bytes.value(); + bool mem_overlap = !(a_end <= b_off || b_end <= a_off); + if (mem_overlap) { + overlap_detected = true; + LOG(WARNING) << "Buffer overlap detected between " << a.name + << " and " << b.name << " (lifetime overlap with " + << "offset ranges [" << a_off << ", " << a_end + << ") and [" << b_off << ", " << b_end << "))."; + } + } + } + } + + if (overlap_detected) { + LOG(WARNING) << "Detected overlapping constant buffers; falling back to " + << "sequential allocation without reuse."; + buffer_byte_offsets_.clear(); + // In the fallback path we simply lay buffers out sequentially. + PrimExpr new_cursor = make_const(offset_dtype, 0); + PrimExpr new_total = make_const(offset_dtype, 0); + for (const BufInfo &info : buf_infos) { + new_cursor = AlignPrimExpr(new_cursor, info.alignment); + PrimExpr size_expr = CastToOffset(info.size_expr); + buffer_byte_offsets_[info.var] = new_cursor; + PrimExpr buf_end = new_cursor + size_expr; + new_total = max(new_total, buf_end); + new_cursor = buf_end; + } + merged_alloc_size_ = buf_infos.empty() + ? make_const(offset_dtype, 0) + : AlignPrimExpr(new_total, align_bytes_); + } + } + + // Whether enable dynamic analysis. + bool is_dynamic_{true}; + + // Whether enable verbose logging. 
+ bool verbose_{false}; + // The alignment bytes for the merged buffer + int align_bytes_{16}; + // The var for the merged buffer + Var merged_buf_var_{"buf_dyn_shmem", + PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; + // The mapping from the original buffer var to its allocate + std::unordered_map shmem_allocs_; + // The size of the merged buffer + PrimExpr merged_alloc_size_{0}; + // The mapping from the original buffer var to its offset in the merged buffer + std::unordered_map buffer_byte_offsets_; + // The mapping from the original buffer objects to their location in the + // merged buffer. + std::unordered_map buffer_remap_; + // The flag indicating whether the merged buffer has been allocated + bool allocated_{false}; + // Locations of free ops. + std::unordered_map event_map_; + // The mapping of buffer bytes alignment + std::unordered_map shmem_alignment_map_; +}; + +Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem, + bool enable_aggressive_merge, + int align_bytes = 16, bool verbose = false) { + AllocateCollector collector; + collector(stmt); + if (collector.dyn_shmem_allocs_.size() > 1) { + SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose, + align_bytes); + rewriter.PlanReuse(stmt, true, enable_aggressive_merge); + stmt = rewriter(std::move(stmt)); + } + if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) { + SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false, + verbose, align_bytes); + rewriter.PlanReuse(stmt, false, enable_aggressive_merge); + stmt = rewriter(std::move(stmt)); + } + return stmt; +} + +using namespace tir::transform; + +namespace transform { + +Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false, + int align_bytes = 16) { + auto pass_func = [enable_aggressive_merge, align_bytes]( + PrimFunc f, const IRModule &m, PassContext ctx) { + bool merge_static_smem = + ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + bool debug_merge_shared_memory_allocations = + ctx->GetConfig(kDebugMergeSharedMemoryAllocations, Bool(false)) + .value(); + auto *n = f.CopyOnWrite(); + n->body = tl::MergeSharedMemoryAllocations( + std::move(n->body), merge_static_smem, enable_aggressive_merge, + align_bytes, debug_merge_shared_memory_allocations); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations", + {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations", + MergeSharedMemoryAllocations); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/multi_version_buffer_rewriter.cc b/tilelang/original/src/transform/multi_version_buffer_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ed9437cffd8faa5e4c22fbd7eae90987a1700f8 --- /dev/null +++ b/tilelang/original/src/transform/multi_version_buffer_rewriter.cc @@ -0,0 +1,503 @@ +/*! 
+ * \file warp_specialized_pipeline.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +enum class Role : uint8_t { kConsumer, kProducer, kBoth }; + +class WarpSpecializedRoleMarker_ : public StmtVisitor { +public: + WarpSpecializedRoleMarker_(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + Role GetRole(const StmtNode *stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); } + + void VisitStmt_(const EvaluateNode *op) final { + Role role = Role::kConsumer; + if (auto call = op->value.as()) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + role = Role::kProducer; + has_bulk_copy_ = true; + } + } + SetRole(op, role); + } + + void VisitStmt_(const BufferStoreNode *op) final { + bool is_shared_store = + op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; + if (!is_shared_store) { + SetRole(op, Role::kConsumer); + return; + } + + // Check reads from global + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ tvm::ffi::GetRef(op)); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + Role role = Role::kProducer; + for (auto read : reads) { + if (read->buffer.scope() != "global") { + role = Role::kConsumer; + break; + } + } + if (role == Role::kProducer) + has_simt_copy_ = true; + SetRole(op, role); + } + + void VisitStmt_(const SeqStmtNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetRole(stmt)) { + role = Role::kBoth; + break; + } + } + SetRole(op, role); + } + + void VisitStmt_(const IfThenElseNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetRole(op->else_case.value()); + if (role != role_else) + role = Role::kBoth; + } + SetRole(op, role); + } + + void VisitStmt_(const BlockRealizeNode *op) final { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->block)); + } + + template void HandleBodyStmt(const NodeType *op) { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->body)); + } + + void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); } + + bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + + bool HasSimtCopy() { return has_simt_copy_; } + +private: + void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; } + Map buffer_data_to_buffer_; + std::unordered_map map_; + bool has_simt_copy_ = false; + bool has_bulk_copy_ = false; +}; + +class MultiVersionBufferRewriter : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f) { + auto rewriter = MultiVersionBufferRewriter(); + rewriter.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : rewriter.buffer_lca_) { + Var buffer_var = buffer->data; + rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer); + } + f.CopyOnWrite()->body = 
rewriter(f->body); + return f; + } + +private: + MultiVersionBufferRewriter() = default; + + Array GetVersionedBuffers(const Array &seq_stmt, + const Array &scoped_buffers) { + Array pipeline_stmts; + std::function collect_stmts = [&](const Stmt &stmt) { + if (const auto *seq = stmt.as()) { + for (const Stmt &s : seq->seq) { + collect_stmts(s); + } + return; + } + if (const auto *let = stmt.as()) { + collect_stmts(let->body); + return; + } + if (const auto *attr = stmt.as()) { + collect_stmts(attr->body); + return; + } + if (const auto *block_realize = stmt.as()) { + collect_stmts(block_realize->block->body); + return; + } + if (const auto *block = stmt.as()) { + collect_stmts(block->body); + return; + } + pipeline_stmts.push_back(stmt); + }; + for (const Stmt &stmt : seq_stmt) { + collect_stmts(stmt); + } + + std::vector roles; + Array> reads, writes; + auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_); + for (const Stmt &stmt : pipeline_stmts) { + marker(stmt); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", /*body*/ stmt); + auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + reads.push_back(access[0]); + writes.push_back(access[1]); + roles.push_back(marker.GetRole(stmt)); + } + + std::unordered_set consumer_used, producer_used; + std::unordered_map first_write_index; + std::unordered_map last_read_index; + auto is_copy_stage = [&](size_t idx) { + bool has_shared_write = false; + for (const BufferRegion &wr : writes[idx]) { + auto scope = wr->buffer.scope(); + if (scope == "shared" || scope == "shared.dyn") { + has_shared_write = true; + break; + } + } + if (!has_shared_write) + return false; + for (const BufferRegion &rd : reads[idx]) { + if (rd->buffer.scope() == "global") { + return true; + } + } + return false; + }; + for (size_t i = 0; i < pipeline_stmts.size(); i++) { + bool copy_stage = is_copy_stage(i); + bool is_producer = roles[i] == Role::kProducer || + (roles[i] == Role::kBoth && copy_stage); + bool is_consumer = roles[i] == Role::kConsumer || + (roles[i] == Role::kBoth && !copy_stage); + if (is_producer) { + for (BufferRegion br : writes[i]) { + producer_used.insert(br->buffer.get()); + } + } + if (is_consumer) { + for (BufferRegion br : reads[i]) { + consumer_used.insert(br->buffer.get()); + } + } + for (BufferRegion br : writes[i]) { + const BufferNode *buf = br->buffer.get(); + if (!first_write_index.count(buf)) { + first_write_index[buf] = i; + } + } + for (BufferRegion br : reads[i]) { + last_read_index[br->buffer.get()] = i; + } + } + Array versioned_buffers; + for (Buffer buffer : scoped_buffers) { + if (consumer_used.count(buffer.get()) && + producer_used.count(buffer.get())) { + versioned_buffers.push_back(buffer); + continue; + } + // Fallback: if we saw a write before a later read, the buffer spans + // multiple stages even if role classification missed one side. 
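      // Illustrative reading (not normative): the check below still versions a
      // buffer whose first write happens in a copy stage (global -> shared) at
      // index i and which is read again at some later index j > i, even if the
      // statement doing the read was not classified as a plain kConsumer above.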
+ auto it_w = first_write_index.find(buffer.get()); + auto it_r = last_read_index.find(buffer.get()); + if (it_w != first_write_index.end() && it_r != last_read_index.end() && + it_w->second < it_r->second) { + if (!is_copy_stage(it_w->second)) + continue; + versioned_buffers.push_back(buffer); + } + } + return versioned_buffers; + } + + static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { + ObjectPtr new_buffer = + tvm::ffi::make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); + if (!new_buffer->strides.empty()) { + ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); + } + return Buffer(new_buffer); + } + + Stmt VisitStmt_(const BlockRealizeNode *op) final { + BlockRealize block_realize = + Downcast(StmtExprMutator::VisitStmt_(op)); + Block block = block_realize->block; + Array alloc_buffers; + for (auto buffer : block->alloc_buffers) { + if (buffer_remap_.count(buffer)) { + Buffer new_buffer = buffer_remap_[buffer]; + alloc_buffers.push_back(new_buffer); + } else { + alloc_buffers.push_back(buffer); + } + } + block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); + // Record the updated alloc list to recover buffers whose LCA is the block. + block_alloc_buffers_[op->block.get()] = block->alloc_buffers; + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + + Stmt VisitStmt_(const BlockNode *op) final { + stmt_stack_.push_back(op); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + stmt_stack_.pop_back(); + return stmt; + } + + Stmt VisitStmt_(const ForNode *op) final { + stmt_stack_.push_back(op); + loop_stack_.emplace_back(op->loop_var, op->extent); + auto num_stages_anno = op->annotations.Get("num_stages"); + if (!num_stages_anno) { + auto for_node = StmtExprMutator::VisitStmt_(op); + loop_stack_.pop_back(); + stmt_stack_.pop_back(); + return for_node; + } + + ICHECK(num_stages_anno->as()); + int num_stages = static_cast(num_stages_anno->as()->value); + + Stmt pipeline_body_root{nullptr}; + if (const auto *realize = op->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body_root = block->body; + } else { + pipeline_body_root = op->body; + } + + const SeqStmtNode *pipeline_body_seq = nullptr; + { + // Traverse trivial wrappers (let/if) to find the actual SeqStmt body. 
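      // Illustrative example (not normative): a pipelined loop body of the form
      //
      //   let cond = (k < K) in
      //     if cond:                       // no else branch allowed
      //       SeqStmt([copy_stage, gemm_stage])
      //
      // unwraps to the inner SeqStmt; an IfThenElse with an else branch, or any
      // other statement kind, is rejected with the FATAL message below.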
+ Stmt current = pipeline_body_root; + while (true) { + if (const auto *seq_stmt = current.as()) { + pipeline_body_seq = seq_stmt; + break; + } + if (const auto *if_then_else = current.as()) { + ICHECK(!if_then_else->else_case.defined()) + << "MultiVersionBuffer: Can't handle the body of the loop " + "because the IfThenElse node has an else branch"; + current = if_then_else->then_case; + continue; + } + if (const auto *let_stmt = current.as()) { + current = let_stmt->body; + continue; + } + LOG(FATAL) + << "MultiVersionBuffer: Can't handle the body of the loop because " + << "it is not a SeqStmt, IfThenElse without else, " + << "or LetStmt wrapping them, but got " << current->GetTypeKey(); + } + } + ICHECK(pipeline_body_seq != nullptr); + + Array scoped_buffers; + std::unordered_set seen; + for (auto [buffer, stmt] : buffer_lca_) { + if (!stmt.defined()) + continue; + const StmtNode *lca = stmt.value().get(); + bool in_scope = false; + for (const StmtNode *ancestor : stmt_stack_) { + if (ancestor == lca) { + in_scope = true; + break; + } + } + if (!in_scope) + continue; + // Only double-buffer shared allocations; locals do not need versioning. + auto scope = buffer.scope(); + if (!(scope == "shared" || scope == "shared.dyn")) + continue; + if (seen.insert(buffer.get()).second) { + scoped_buffers.push_back(buffer); + } + } + for (auto it = stmt_stack_.rbegin(); it != stmt_stack_.rend(); ++it) { + if (!(*it)->IsInstance()) + continue; + const auto *block = static_cast(*it); + auto map_it = block_alloc_buffers_.find(block); + if (map_it == block_alloc_buffers_.end()) + continue; + for (const Buffer &buffer : map_it->second) { + auto scope = buffer.scope(); + if (!(scope == "shared" || scope == "shared.dyn")) + continue; + if (seen.insert(buffer.get()).second) { + scoped_buffers.push_back(buffer); + } + } + } + + Array versioned_buffers = + GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers); + + for (auto buffer : versioned_buffers) { + Var buffer_var = buffer->data; + Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); + buffer_remap_.Set(buffer, new_buffer); + } + PrimExpr linear_index = loop_stack_[0].first; + for (size_t i = 1; i < loop_stack_.size(); ++i) { + linear_index = + linear_index * loop_stack_[i].second + loop_stack_[i].first; + } + version_index_ = FloorMod(linear_index, num_stages); + auto for_node = StmtExprMutator::VisitStmt_(op); + loop_stack_.pop_back(); + stmt_stack_.pop_back(); + + return for_node; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return std::move(load); + } + const Buffer &new_buffer = (*it).second; + auto *n = load.CopyOnWrite(); + n->buffer = new_buffer; + n->indices.insert(n->indices.begin(), version_index_); + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return std::move(store); + } + const Buffer &new_buffer = (*it).second; + auto *n = store.CopyOnWrite(); + n->buffer = new_buffer; + n->indices.insert(n->indices.begin(), version_index_); + return std::move(store); + } + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(builtin::tvm_access_ptr())) { + return RewriteBufferAccess(call, {1}); + } + return call; + } + + 
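  // Illustrative example (not normative): for a two-stage pipeline
  // (version_index_ = floormod(k, 2) for a single loop var k) and a shared
  // buffer A of original shape {64, 64} that was rewritten to {2, 64, 64},
  // an access pointer call
  //
  //   tvm_access_ptr(handle, A, offset, extent, rw_mask)
  //
  // is rewritten below so that the element offset becomes
  //   offset + version_index_ * (64 * 64)
  // i.e. the per-version stride is the product of the original shape, or the
  // new leading stride when explicit strides are present.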
PrimExpr RewriteBufferAccess(const Call &call, + const std::vector &arg_indices) { + auto product = [](const Array &input) { + return foldl( + [](PrimExpr a, PrimExpr b, Span span) { + return mul(std::move(a), std::move(b), std::move(span)); + }, + make_const(DataType::Int(32), 1), input); + }; + Array new_args = call->args; + for (int i : arg_indices) { + auto buffer_var = Downcast(call->args[i]); + if (!buffer_data_to_buffer_.count(buffer_var)) + continue; + const Buffer &buffer = buffer_data_to_buffer_[buffer_var]; + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + const Buffer &new_buffer = (*it).second; + const PrimExpr &old_index = call->args[i + 1]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = product(buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + version_index_ * offset; + new_args.Set(i + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } + + PrimExpr version_index_; + std::vector> loop_stack_; + // Track ancestor statements to query whether an LCA is inside the current + // loop. + std::vector stmt_stack_; + Map buffer_data_to_buffer_; + Map> buffer_lca_; + Map buffer_remap_; + // Remember each block's alloc list so the loop can see buffers defined in + // parents. + std::unordered_map> block_alloc_buffers_; +}; + +using namespace tir::transform; + +tvm::transform::Pass MultiVersionBuffer() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return MultiVersionBufferRewriter::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/persist_threadblock.cc b/tilelang/original/src/transform/persist_threadblock.cc new file mode 100644 index 0000000000000000000000000000000000000000..b64ffdcce85ab1bfdbfec86ede2ae9ba599dceac --- /dev/null +++ b/tilelang/original/src/transform/persist_threadblock.cc @@ -0,0 +1,68 @@ +/*! 
+ * \file lower_l2_persistent_annotation.cc + * \brief Lower L2 persistent annotation + */ + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../runtime/runtime.h" + +namespace tvm { +namespace tl { + +namespace attr { +// BlockAttr, Containing the layout for all the buffers in the block +constexpr const char *kUseCooperativeGroups = "use_cooperative_groups"; +} // namespace attr + +using namespace tir; + +class PersistThreadblock : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f) { + PrimFuncNode *fptr = f.CopyOnWrite(); + PersistThreadblock substituter; + // Trace the buffer map for tvm_access_ptr + fptr->body = substituter.VisitStmt(f->body); + if (substituter.has_sync_grid_) { + f = WithAttr(std::move(f), attr::kUseCooperativeGroups, + IntImm(DataType::Int(32), 1)); + } + return f; + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + if (const auto *call = op->value.as()) { + if (call->op.same_as(sync_grid())) { + has_sync_grid_ = true; + } + } + return StmtExprMutator::VisitStmt_(op); + } + +private: + PersistThreadblock() = default; + bool has_sync_grid_ = false; +}; + +using namespace tir::transform; + +tvm::transform::Pass PersistThreadblock() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return PersistThreadblock::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/pipeline_planning.cc b/tilelang/original/src/transform/pipeline_planning.cc new file mode 100644 index 0000000000000000000000000000000000000000..717dce27f7e1b2641f9d20e57e6c156fa07c23d6 --- /dev/null +++ b/tilelang/original/src/transform/pipeline_planning.cc @@ -0,0 +1,737 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include +#include + +#include "../target/utils.h" +#include "tvm/ir/expr.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Check whether two regions have intersections. + * \param region1 The first region. + * \param region2 The second region. + * \return Whether region1 and region2 have intersections. + */ +bool MayConflict(const Region ®ion1, const Region ®ion2) { + ICHECK(region1.size() == region2.size()); + for (size_t i = 0; i < region1.size(); i++) { + Range dim1 = region1[i]; + Range dim2 = region2[i]; + auto int_set1 = arith::IntSet::FromRange(dim1); + auto int_set2 = arith::IntSet::FromRange(dim2); + if (arith::Intersect({int_set1, int_set2}).IsNothing()) { + return false; + } + } + return true; +} + +class TmemLoadCollector : public StmtExprVisitor { +public: + TmemLoadCollector() {} + + Buffer result; + +private: + void VisitExpr_(const BufferLoadNode *op) { + Buffer buf = op->buffer; + if (buf->data->type_annotation.as()->storage_scope == + "shared") { + // We only care about shared.tmem buffers + ICHECK(!result.defined()) + << "TmemLoadCollector: More than one shared buffer visited"; + result = buf; + } + } +}; + +/*! + * \brief Build the dependency chain between async operations and their + * corresponding buffers & synchronizations. + * + * Example: + * If we encounter the following pattern: + * + * tcgen5mma_gemm_ts(..., mbar, ...) 
+ * mbarrier_wait_parity(mbar) + * + * The builder will link the mbarrier to the buffers used in the + * TCGEN5MMA + */ +class AsyncDependencyChainBuilder : public StmtExprVisitor { +public: + AsyncDependencyChainBuilder(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(buffer_data_to_buffer) {} + + std::unordered_map> + mbar_to_buffer_reads_; + + std::unordered_map> + mbar_to_buffer_writes_; + +private: + Map buffer_data_to_buffer_; + + void VisitExpr_(const CallNode *op) final { + auto args = op->args; + if (op->op.same_as(builtin::call_extern())) { + std::string func_name_with_template = args[0].as()->value; + std::size_t le_pos = func_name_with_template.find_first_of('<'); + std::string func_name = le_pos == std::string::npos + ? func_name_with_template + : func_name_with_template.substr(0, le_pos); + // TODO(lei): refactor to use identical ops. + if (func_name == "tl::tcgen5mma_gemm_ts" || + func_name == "tl::tcgen5mma_gemm_ss") { + // TCGEN5MMA + auto get_buf_from_access_ptr_call = + [&](const PrimExpr &expr) -> Buffer { + auto call = expr.as(); + ICHECK(call); + ICHECK(call->op.same_as(builtin::tvm_access_ptr())); + auto var = call->args[1].as(); + ICHECK(var); + auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef(var)); + ICHECK(it != buffer_data_to_buffer_.end()); + return (*it).second; + }; + Buffer a_buf = get_buf_from_access_ptr_call(args[1]); + Buffer b_buf = get_buf_from_access_ptr_call(args[2]); + Buffer mbar_buf = get_buf_from_access_ptr_call(args[4]); + + TmemLoadCollector tmem_collector; + tmem_collector(args[3]); + ICHECK(tmem_collector.result.defined()) + << "TmemLoadCollector: No tmem buffer load found in the TCGEN5MMA " + "call"; + Buffer c_buf = tmem_collector.result; + + PrimExpr clear_accum = args[5]; + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(a_buf)); + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(b_buf)); + mbar_to_buffer_writes_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(c_buf)); + auto analyzer = std::make_shared(); + if (!analyzer->CanProveEqual(clear_accum, Bool(true))) { + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(c_buf)); + } + } + // TODO (lei) Link wgmma to buffers and tl.wait_wgmma + } else if (op->op.same_as(tir::builtin::if_then_else())) { + const PrimExpr &then_expr = args[1]; + const PrimExpr &else_expr = args[2]; + this->VisitExpr(then_expr); + this->VisitExpr(else_expr); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } +}; + +/*! + * \brief Detect if a statement follows the global memory copy pattern: + * 1. Contains exactly one buffer store operation + * 2. Source buffer must be in global memory scope + * 3. 
Destination buffer must be in local or shared memory scope + */ +class BufferRegionCollector : public StmtExprVisitor { +public: + BufferRegionCollector(Map buffer_data_to_buffer, + const AsyncDependencyChainBuilder &chain_builder) + : buffer_data_to_buffer_(buffer_data_to_buffer), + chain_builder_(chain_builder) {} + + Array GetReads() const { return reads_; } + + Array GetWrites() const { return writes_; } + + bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; } + +private: + void VisitStmt_(const BufferStoreNode *op) final { + Buffer store_buffer = op->buffer; + Array indices = op->indices; + // convert indices to region + Array region; + for (const auto &index : indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + auto store_region = BufferRegion(store_buffer, region); + writes_.push_back(store_region); + + is_global_read_ = false; + this->VisitExpr(op->value); + if (is_global_read_ && (store_buffer.scope() == "shared" || + store_buffer.scope() == "shared.dyn")) { + is_global_copy_pattern_ = true; + } + is_global_read_ = false; + } + + void VisitExpr_(const BufferLoadNode *op) final { + auto load_buffer = op->buffer; + Array indices = op->indices; + // convert indices to region + Array region; + for (const auto &index : indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + auto load_region = BufferRegion(load_buffer, region); + reads_.push_back(load_region); + + if (op->buffer.scope() == "global" && !within_condition_expr_) { + // skip condition expr of if_then_else node + // shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i]) + // is not a global read shared[i] = T.if_then_else(global[i] < n, + // global_a[i], global_b[i]) is a global read + is_global_read_ = true; + } + } + + void VisitExpr_(const CallNode *op) final { + auto args = op->args; + if (op->op.same_as(builtin::address_of())) { + BufferRegion buffer_region; + if (const auto *load = op->args[0].as()) { + buffer_region = BufferRegion::FullRegion(load->buffer); + } else if (const auto *var_node = op->args[0].as()) { + Var data_var = tvm::ffi::GetRef(var_node); + auto it = buffer_data_to_buffer_.find(data_var); + if (it != buffer_data_to_buffer_.end()) { + buffer_region = BufferRegion::FullRegion((*it).second); + } + } + if (buffer_region.defined()) { + // because we only care about the buffer itself instead of indices + reads_.push_back(buffer_region); + } + } else if (op->op.same_as(builtin::tvm_access_ptr())) { + const VarNode *buffer_var = op->args[1].as(); + ICHECK(buffer_var); + auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)); + if (it != buffer_data_to_buffer_.end()) { + const Buffer &buffer = (*it).second; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + // because we only care about the buffer itself instead of indices + reads_.push_back(buffer_region); + } + } else if (op->op.same_as(builtin::if_then_else())) { + within_condition_expr_ = true; + this->VisitExpr(op->args[0]); + within_condition_expr_ = false; + for (auto i = 1; i < op->args.size(); i++) { + this->VisitExpr(op->args[i]); + } + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + ICHECK(args[0].as()); + Buffer mbar_buf = args[0].as()->buffer; + auto buffer_reads = + chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get()); + auto buffer_writes = + chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get()); + if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) { + reads_.insert(reads_.end(), buffer_reads->second.begin(), + 
buffer_reads->second.end()); + } + if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { + writes_.insert( + writes_.end(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + } + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + + void VisitStmt_(const IfThenElseNode *op) final { + within_condition_expr_ = true; + this->VisitExpr(op->condition); + within_condition_expr_ = false; + this->VisitStmt(op->then_case); + if (op->else_case.defined()) { + within_condition_expr_ = true; + this->VisitStmt(op->else_case.value()); + within_condition_expr_ = false; + } + } + +private: + AsyncDependencyChainBuilder chain_builder_; + Map buffer_data_to_buffer_; + Array reads_; + Array writes_; + bool is_global_read_ = false; + bool under_buffer_store_ = false; + bool is_global_copy_pattern_ = false; + bool within_condition_expr_ = false; +}; + +class PipelinePlanner : public StmtExprMutator { +public: + static Stmt Substitute(const PrimFunc &f, bool use_async_copy = true) { + PipelinePlanner substituter(use_async_copy); + for (const auto &[_, buffer] : f->buffer_map) { + substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) + << "Pipeline_Planning: Require the target attribute"; + substituter.target_ = target.value(); + return substituter.VisitStmt(f->body); + } + +private: + PipelinePlanner() = default; + PipelinePlanner(bool use_async_copy) : use_async_copy_(use_async_copy) {} + + /*! \brief Information about a pipeline stage + * + * \param reads Array of buffer regions read by this stage + * \param writes Array of buffer regions written by this stage + * \param original_stmt_index Original position of this stage in the pipeline + * before reordering \param order Current position of this stage in the + * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline + * stage number this operation belongs to (-1 if not yet assigned) \param + * copy_stage Whether this stage is a memory copy operation \param + * last_use_stmt_index Index of the last statement (in original order) that + * uses the results of this stage (-1 if not yet determined). 
This field is + * crucial for pipeline optimization: + * - For copy stages: indicates the index of the last statement that reads + * from the copied data, helping determine optimal placement of copy + * operations + * - Used to ensure copy operations are scheduled before their consumers + * - A value of -1 means no subsequent statement uses this stage's output + * - This information enables better pipeline scheduling by minimizing data + * dependencies and maximizing parallelism + */ + struct PipelineStageInfo { + Array reads, writes; + int original_stmt_index{}; + int order = -1, stage = -1; + bool copy_stage = false; + bool producer_for_copy = false; + int last_use_stmt_index = + -1; // Initialized to -1, indicating no consumers found yet + + public: + bool is_first_stage() const { return copy_stage || producer_for_copy; } + bool is_copy_stage() const { return copy_stage; } + bool is_producer_for_copy() const { return producer_for_copy; } + bool is_last_use_stmt_index_valid() const { + return last_use_stmt_index != -1; + } + }; + + PipelineStageInfo + MakePipelineStageInfo(Stmt stmt, int idx, + AsyncDependencyChainBuilder &chain_builder) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ std::move(stmt)); + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto collector = + BufferRegionCollector(buffer_data_to_buffer_, chain_builder); + collector(block); + PipelineStageInfo pinfo; + pinfo.reads = std::move(collector.GetReads()); + pinfo.writes = std::move(collector.GetWrites()); + pinfo.original_stmt_index = idx; + pinfo.copy_stage = collector.GetGlobalCopyPattern(); + return std::move(pinfo); + } + + Stmt VisitStmt_(const ForNode *loop) final { + auto order_anno = loop->annotations.Get("tl_pipeline_order"); + auto stage_anno = loop->annotations.Get("tl_pipeline_stage"); + auto num_stages_anno = loop->annotations.Get("num_stages"); + if (order_anno && stage_anno) { + // Check if order_anno or stage_anno contains -1, which means TMA+WS is + // enabled + bool ws_tma_enabled = false; + auto order_array = Downcast>(order_anno.value()); + auto stage_array = Downcast>(stage_anno.value()); + for (const auto &val : order_array) { + if (val->value == -1) { + ws_tma_enabled = true; + break; + } + } + if (!ws_tma_enabled) { + for (const auto &val : stage_array) { + if (val->value == -1) { + ws_tma_enabled = true; + break; + } + } + } + + if (ws_tma_enabled) { + return StmtExprMutator::VisitStmt_(loop); + } + + Map annotations; + for (const auto &[key, value] : loop->annotations) { + if (key != "tl_pipeline_order") { + annotations.Set(key, value); + } + } + annotations.Set(tir::attr::software_pipeline_order, order_anno.value()); + + for (const auto &[key, value] : loop->annotations) { + if (key != "tl_pipeline_stage") { + annotations.Set(key, value); + } + } + annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value()); + if (TargetHasAsyncCopy(target_) && use_async_copy_) + annotations.Set(tir::attr::software_pipeline_async_stages, + Array{0}); + auto for_node = tvm::ffi::GetRef(loop); + for_node.CopyOnWrite()->annotations = annotations; + return for_node; + } + + if (!num_stages_anno) + return StmtExprMutator::VisitStmt_(loop); + int num_stages = num_stages_anno->as()->value; + Stmt pipeline_body_root{nullptr}; + if (const auto *realize = loop->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, 
buffer); + } + pipeline_body_root = block->body; + } else { + pipeline_body_root = loop->body; + } + const SeqStmtNode *pipeline_body_seq = nullptr; + { + Stmt current = pipeline_body_root; + while (true) { + if (const auto *seq_stmt = current.as()) { + pipeline_body_seq = seq_stmt; + break; + } + if (const auto *if_then_else = current.as()) { + ICHECK(!if_then_else->else_case.defined()) + << "Pipeline_Planning: Can't handle the body of the loop because " + "the IfThenElse node has an else branch"; + current = if_then_else->then_case; + continue; + } + if (const auto *let_stmt = current.as()) { + current = let_stmt->body; + continue; + } + LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop " + << "because it is not a SeqStmt, IfThenElse without else, " + << "or LetStmt wrapping them, but got " + << current->GetTypeKey(); + } + } + ICHECK(pipeline_body_seq != nullptr); + + CHECK(num_stages >= 1); + CHECK(loop->kind == ForKind::kSerial); + + AsyncDependencyChainBuilder chain_builder(buffer_data_to_buffer_); + chain_builder(pipeline_body_root); + + std::vector pipeline_stage_infos; + for (size_t i = 0; i < pipeline_body_seq->size(); i++) { + auto pinfo = + MakePipelineStageInfo(pipeline_body_seq->seq[i], i, chain_builder); + pipeline_stage_infos.push_back(std::move(pinfo)); + } + + // For every copy stage, mark all its dependency stages as producer_for_copy + // Helper struct to manage copy stage dependency reads + struct CopyStageDependencyReadsManager { + std::vector regions; + + // Add a region if not already present (by structural equality) + void AddUnique(const BufferRegion ®ion) { + for (const BufferRegion ©_read : regions) { + if (region->buffer.same_as(copy_read->buffer)) { + return; + } + } + regions.push_back(region); + } + + // Check if a region is present (by structural equality) + bool Contains(const BufferRegion ®ion) const { + for (const BufferRegion ©_read : regions) { + if (region->buffer.same_as(copy_read->buffer)) { + return true; + } + } + return false; + } + + size_t Size() const { return regions.size(); } + }; + + CopyStageDependencyReadsManager copy_stage_dependency_reads_mgr; + + // Step 1. Collect Copy reads + for (const auto &pinfo : pipeline_stage_infos) { + if (pinfo.is_copy_stage()) { + for (const BufferRegion &read : pinfo.reads) { + copy_stage_dependency_reads_mgr.AddUnique(read); + } + } + } + + // Step 2. find if pinfo write the copy reads, then update the + // copy_stage_dependency_reads To prevent infinite loops, we set a maximum + // number of iterations. In theory, the number of possible updates is + // bounded by the number of pipeline stages, since each stage can only be + // marked as producer_for_copy once, and each read can only be added once. + // But for safety, we add a hard limit. 
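      // Editorial note (hypothetical example, not from the original source),
      // to make the fixed point concrete for a loop body such as:
      //   S0: reg[i]    = f(other[i]);          // plain compute
      //   S1: idx_sh[i] = reg[i];               // writes a buffer the copy reads
      //   S2: dat_sh[i] = global[idx_sh[i]];    // global->shared copy stage
      // Round 1: S2 is a copy stage, so its reads {idx_sh, global} seed the set.
      // Round 2: S1 writes idx_sh -> marked producer_for_copy, and its read
      //          {reg} joins the set.
      // Round 3: S0 writes reg -> marked producer_for_copy; no new reads are
      //          added, so the loop below terminates well before the hard limit.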
+ const size_t max_iterations = (pipeline_stage_infos.size() * 4) + 16; + size_t iter_count = 0; + + for (auto &pinfo : pipeline_stage_infos) { + if (!pinfo.is_copy_stage()) { + continue; + } + auto original_copy_stmt_index = pinfo.original_stmt_index; + bool updated = true; + while (updated) { + updated = false; + for (auto &pinfo_inner : pipeline_stage_infos) { + if (pinfo_inner.is_copy_stage()) { + continue; + } + if (pinfo_inner.original_stmt_index >= original_copy_stmt_index) { + break; + } + + bool should_prepare = false; + for (const BufferRegion &write : pinfo_inner.writes) { + if (copy_stage_dependency_reads_mgr.Contains(write)) { + should_prepare = true; + break; + } + } + if (should_prepare && !pinfo_inner.is_producer_for_copy()) { + pinfo_inner.producer_for_copy = true; + updated = true; + } + if (should_prepare) { + for (const BufferRegion &read : pinfo_inner.reads) { + size_t before = copy_stage_dependency_reads_mgr.Size(); + copy_stage_dependency_reads_mgr.AddUnique(read); + if (copy_stage_dependency_reads_mgr.Size() > before) { + updated = true; + } + } + } + } + iter_count++; + if (iter_count > max_iterations) { + LOG(FATAL) + << "Pipeline planning: Exceeded maximum iterations (" + << max_iterations << ") in copy stage dependency propagation. " + << "This may indicate a cyclic or pathological dependency graph."; + } + } + } + + // Analysis use-def chain to determine last_use_stmt_index for copy + // operations This step is critical for pipeline optimization as it + // identifies the index of the last statement that consumes data produced by + // copy stages, enabling optimal placement of copy operations in the + // pipeline schedule. + for (auto &pinfo : pipeline_stage_infos) { + // Only analyze copy stages (memory copy operations) + if (!pinfo.is_first_stage()) + continue; + + // Check all subsequent statements to find the latest consumer + for (int i = pinfo.original_stmt_index + 1; + i < static_cast(pipeline_body_seq->size()); i++) { + + // Check if any read operation in statement 'i' uses data written by + // this copy stage + for (const BufferRegion &read : pipeline_stage_infos[i].reads) { + // Look for overlapping buffer regions between this stage's writes and + // stage 'i's reads + if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), + [&](const BufferRegion &r) { + return r->buffer == read->buffer && + MayConflict(r->region, read->region); + }) != pinfo.writes.end()) { + // Update last_use_stmt_index to the maximum (latest) statement + // index that uses this data This ensures we capture the final + // consumer of the copied data + pinfo.last_use_stmt_index = std::max(pinfo.last_use_stmt_index, i); + } + } + // Check for write-after-write conflicts (multiple stages writing to + // same buffer region) This is important for pipeline correctness and + // affects last_use_stmt_index analysis + if (pinfo.is_copy_stage()) { + for (const BufferRegion &write : pipeline_stage_infos[i].writes) { + if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), + [&](const BufferRegion &r) { + return r->buffer == write->buffer && + MayConflict(r->region, write->region); + }) != pinfo.writes.end()) { + LOG(FATAL) << "Pipeline planning error: Multiple writes to " + "overlapping buffer regions detected. " + << "Stage " << pinfo.original_stmt_index + << " and stage " << i + << " are both writing to buffer '" + << write->buffer->name + << "' with overlapping regions. 
This is not supported " + "in pipeline planning."; + } + } + } + } + } + + // Making stages and orders + int order_idx = 0; + // Stage 1. Create pipeline stages and assign order + for (auto &pinfo : pipeline_stage_infos) { + // Skip elements that must be in first stage: + // 1. Copy stages (with active last_use_stmt_index) - these need special + // handling + // because they have consumers that depend on their data + // 2. All Producer stages for copy stages. + if (pinfo.is_first_stage() && pinfo.is_last_use_stmt_index_valid()) { + continue; + } + + // Main logic stage assignment: + // - Increment order index + // - Assign to new stage (current num_stages) + pinfo.order = order_idx++; + pinfo.stage = num_stages; + + // Schedule copy stages that have this stage as their last consumer + // This ensures copy operations are placed right before their final + // consumer for optimal pipeline efficiency + for (auto &pinfo_1 : pipeline_stage_infos) { + if ((pinfo_1.is_first_stage() && + pinfo_1.last_use_stmt_index == pinfo.original_stmt_index)) { + pinfo_1.order = order_idx++; + pinfo_1.stage = 0; // Copy stages are typically assigned to stage 0 + } + } + } + + ICHECK(size_t(order_idx) == pipeline_stage_infos.size()) + << "The number of stages should be equal to the number of pipeline " + "stages. " + << "Got " << order_idx << " stages and " << pipeline_stage_infos.size() + << " pipeline stages."; + + // Step 2. if all the copy is at the end of the order, we can move these + // copy to the beginning of the order and shrink the stage offset by 1. + int copy_stage_at_end = [&]() { + int copy_stage_cnt = 0; + int copy_order_min = pipeline_stage_infos.size(); + int non_copy_order_max = 0; + for (auto &pinfo : pipeline_stage_infos) { + if (pinfo.is_first_stage()) { + copy_stage_cnt++; + copy_order_min = std::min(copy_order_min, pinfo.order); + } else { + non_copy_order_max = std::max(non_copy_order_max, pinfo.order); + } + } + if (copy_order_min > non_copy_order_max) + return copy_stage_cnt; + return -1; + }(); + if (copy_stage_at_end > 0 && num_stages >= 2) { + for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning + pinfo.order = + (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size(); + if (!pinfo.is_copy_stage() && !pinfo.is_producer_for_copy()) + pinfo.stage--; + } + } + + // Finally, make the pipeline annotation + Map annotations; + for (const auto &[key, value] : loop->annotations) { + if (key != "num_stages") { + annotations.Set(key, value); + } + } + + std::vector orders, stages; + orders.reserve(pipeline_stage_infos.size()); + stages.reserve(pipeline_stage_infos.size()); + for (auto &pinfo : pipeline_stage_infos) { + orders.push_back(pinfo.order); + stages.push_back(pinfo.stage); + } + + annotations.Set(tir::attr::software_pipeline_stage, Array(stages)); + annotations.Set(tir::attr::software_pipeline_order, Array(orders)); + if (TargetHasAsyncCopy(target_) && use_async_copy_) + annotations.Set(tir::attr::software_pipeline_async_stages, + Array{0}); + + return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body, + loop->thread_binding, annotations); + } + + Stmt VisitStmt_(const BlockNode *op) final { + for (const auto &buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + for (const auto &buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + return std::move(block); + } + + Map buffer_data_to_buffer_; + Target target_; + bool 
use_async_copy_{}; +}; + +tvm::transform::Pass PipelinePlanning() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool use_async_copy = + ctx->GetConfig("tir.use_async_copy", Bool(true)).value(); + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = PipelinePlanner::Substitute(f, use_async_copy); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/plan_update_buffer_allocation_location.cc b/tilelang/original/src/transform/plan_update_buffer_allocation_location.cc new file mode 100644 index 0000000000000000000000000000000000000000..995b21519de8945a6cac13370d4028405dbfb448 --- /dev/null +++ b/tilelang/original/src/transform/plan_update_buffer_allocation_location.cc @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Planning where buffers to be allocated and update the AST. + * \file plan_update_buffer_allocation_location.cc + */ + +#include +#include +#include +#include +#include + +#include "tir/transforms/ir_utils.h" + +// Forward-declare tir's var-level LCA helper which has no public header. +namespace tvm { +namespace tir { +ffi::Map> +DetectBufferVarAccessLCA(const PrimFunc &func); +} +} // namespace tvm + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace tir::transform; + +// Use TVM's tir analysis API for LCA detection. + +class CollectManagedAllocations : public StmtExprVisitor { +public: + void VisitStmt_(const BlockNode *op) final { + for (const auto &buf : op->alloc_buffers) { + managed_allocations.insert(buf->data.get()); + } + for (const auto &buf : op->match_buffers) { + managed_allocations.insert(buf->buffer->data.get()); + } + StmtExprVisitor::VisitStmt_(op); + } + + /*! \brief Buffers that are allocated outside of the BlockNode, and should not + * be moved by BufferAllocationLocator. */ + std::unordered_set managed_allocations; +}; + +/*! \brief Collect the allocate buffer order. 
*/ +class BufferAllocateOrderCollector : public StmtExprVisitor { +public: + static ffi::Array Collect(const PrimFunc &func) { + BufferAllocateOrderCollector collector; + for (const auto &kv : func->buffer_map) { + collector.buffer_alloc_recorder_.push_back(kv.second); + } + collector(func->body); + return std::move(collector.buffer_alloc_recorder_); + } + +private: + bool find(const Buffer &buf) { + return std::find(buffer_alloc_recorder_.begin(), + buffer_alloc_recorder_.end(), + buf) != buffer_alloc_recorder_.end(); + } + + void VisitStmt_(const BlockNode *op) final { + for (const Buffer &buffer : op->alloc_buffers) { + buffer_alloc_recorder_.push_back(buffer); + } + // Also visit match_buffers to collect constant buffers associated with + // AllocateConst nodes. These buffers only appear in read and match_buffer + // regions. + for (const auto ®ion : op->match_buffers) { + if (!find(region->source->buffer)) { + buffer_alloc_recorder_.push_back(region->source->buffer); + } + } + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (!find(op->buffer)) { + buffer_alloc_recorder_.push_back(op->buffer); + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + if (!find(op->buffer)) { + buffer_alloc_recorder_.push_back(op->buffer); + } + StmtExprVisitor::VisitStmt_(op); + } + + /*! \brief The buffer allocated order recorder. */ + ffi::Array buffer_alloc_recorder_; +}; + +class BufferAllocationLocator : public StmtExprMutator { +public: + explicit BufferAllocationLocator(const PrimFunc &func) { + // Use TVM's tir LCA detection implementation + ffi::Map> buffer_lca = + tir::DetectBufferAccessLCA(func); + ffi::Map> var_lca = + tir::DetectBufferVarAccessLCA(func); + + // The buffer_alloc_recorder Array is used to keep the buffer allocation + // order since the buffer_lca Map is unordered. + ffi::Array buffer_alloc_recorder = + BufferAllocateOrderCollector::Collect(func); + std::unordered_set arg_buffer_vars; + CollectManagedAllocations collector; + collector(func->body); + managed_allocations_ = collector.managed_allocations; + + for (const auto &kv : func->buffer_map) { + const Buffer &buffer = kv.second; + arg_buffer_vars.emplace(buffer->data.get()); + PushBinding(buffer->data, buffer); + } + // create buffers to be allocated at each stmts + for (const auto &buffer : buffer_alloc_recorder) { + // Prefer the LCA derived from the underlying data var. If missing, fall + // back to Buffer LCA. + const StmtNode *stmt = nullptr; + auto vit = var_lca.find(buffer->data); + if (vit != var_lca.end()) { + stmt = (*vit).second.get(); + } else { + auto bit = buffer_lca.find(buffer); + if (bit != buffer_lca.end()) { + stmt = (*bit).second.get(); + } + } + if (stmt != nullptr || vit != var_lca.end()) { + if (arg_buffer_vars.count(buffer->data.get())) { + continue; + } + if (managed_allocations_.count(buffer->data.get())) { + alloc_buffers_[stmt].push_back(buffer); + } + // Do not push binding here. Bindings should reflect scope accurately, + // and will be pushed/popped when visiting the owning stmt. + } + } + } + +private: + // Maintain a stack of Buffers per data var to correctly handle cases + // where multiple Buffer objects share the same underlying data Var. 
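  // Editorial note (illustrative, the buffer names are hypothetical): if an
  // inner scope declares a second view over the same allocation, e.g.
  //   outer:  Buffer A      with data var `a`, shape (64, 64)
  //   inner:  Buffer A_flat with data var `a`, shape (4096,)
  // then `a` temporarily maps to [A, A_flat]; lookups such as SnapshotVarMap
  // resolve to the innermost entry (A_flat), and PopBinding restores A when
  // the inner scope is left.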
+ void PushBinding(const Var &v, const Buffer &buf) { + ffi::Array arr; + auto it = buffer_data_to_buffers_.find(v); + if (it != buffer_data_to_buffers_.end()) { + arr = (*it).second; + } + arr.push_back(buf); + buffer_data_to_buffers_.Set(v, arr); + } + + void PopBinding(const Var &v) { + auto it = buffer_data_to_buffers_.find(v); + if (it == buffer_data_to_buffers_.end()) + return; + ffi::Array arr = (*it).second; + if (!arr.empty()) { + // erase last element + std::vector tmp; + tmp.reserve(arr.size() - 1); + for (size_t i = 0; i + 1 < arr.size(); ++i) + tmp.push_back(arr[i]); + arr = ffi::Array(tmp); + } + if (arr.empty()) { + buffer_data_to_buffers_.erase(v); + } else { + buffer_data_to_buffers_.Set(v, arr); + } + } + + bool HasBinding(const Var &v) const { + auto it = buffer_data_to_buffers_.find(v); + return it != buffer_data_to_buffers_.end() && !(*it).second.empty(); + } + + // Snapshot the current top binding per Var for APIs that require + // a single Buffer per data Var (e.g. GetBlockReadWriteRegion). + ffi::Map SnapshotVarMap() const { + ffi::Map out; + for (const auto &kv : buffer_data_to_buffers_) { + const Var &v = kv.first; + const ffi::Array &arr = kv.second; + if (!arr.empty()) { + out.Set(v, arr[arr.size() - 1]); + } + } + return out; + } + + Stmt VisitStmt_(const ForNode *op) final { + auto it = alloc_buffers_.find(op); + if (it == alloc_buffers_.end()) { + return StmtMutator::VisitStmt_(op); + } + for (const Buffer &buf : it->second) { + PushBinding(buf->data, buf); + } + auto node = Downcast(StmtMutator::VisitStmt_(op)); + ffi::Array new_block_alloc_bufs; + for (const Buffer &buf : it->second) { + if (managed_allocations_.count(buf->data.get())) { + PopBinding(buf->data); + new_block_alloc_bufs.push_back(buf); + } + } + + if (!new_block_alloc_bufs.empty()) { + node.CopyOnWrite()->body = + InjectOpaqueBlock(node->body, new_block_alloc_bufs); + } + + return node; + } + + Stmt VisitStmt_(const BlockNode *op) final { + ICHECK(!op->init.defined()); + ffi::Array alloc_buffers; + auto it = alloc_buffers_.find(op); + if (it != alloc_buffers_.end()) { + alloc_buffers = it->second; + for (const Buffer &buf : it->second) { + PushBinding(buf->data, buf); + } + } + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var &target_var = match_buffer->buffer->data; + const Var &source_var = match_buffer->source->buffer->data; + ICHECK(HasBinding(source_var)); + PushBinding(target_var, match_buffer->buffer); + } + Stmt stmt = StmtMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + // No longer consider buffers created by match_buffer inside the block when + // updating access region. + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var &target_var = match_buffer->buffer->data; + PopBinding(target_var); + } + // No longer consider buffers allocated inside the block when updating + // access region. + if (it != alloc_buffers_.end()) { + for (const Buffer &buf : it->second) { + PopBinding(buf->data); + } + } + + ObjectPtr n = CopyOnWrite(op); + n->alloc_buffers = std::move(alloc_buffers); + // Erase buffer allocated inside the block from access region. 
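    // Editorial note: RemoveRedundantBufferRegion keeps only regions whose
    // data Var still has an active binding. The bindings for buffers that are
    // now allocated (or match_buffer'd) inside this block were popped just
    // above, so exactly those regions drop out of reads/writes here.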
+ n->reads = RemoveRedundantBufferRegion(n->reads); + n->writes = RemoveRedundantBufferRegion(n->writes); + return Stmt(n); + } + + Stmt VisitStmt_(const BufferRealizeNode *op) final { + ICHECK(false) + << "Internal Error: BufferRealizeNode is not allowed in TensorIR."; + throw; + } + + Stmt InjectOpaqueBlock(Stmt body, const ffi::Array &alloc_buffers) { + ICHECK(!alloc_buffers.empty()); + Block opaque_block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/"", + /*body=*/std::move(body), + /*init=*/std::nullopt, + /*alloc_buffers=*/alloc_buffers); + ObjectPtr n = CopyOnWrite(opaque_block.get()); + // Snapshot to a Var->Buffer map using the innermost binding for each Var. + ffi::Map var_map = SnapshotVarMap(); + ffi::Array> access = + GetBlockReadWriteRegion(opaque_block, var_map); + n->reads = access[0]; + n->writes = access[1]; + BlockRealize realize({}, Bool(true), Block(n)); + return realize; + } + + ffi::Array + RemoveRedundantBufferRegion(const ffi::Array ®ion) const { + ffi::Array result; + for (const BufferRegion &buffer_region : region) { + if (HasBinding(buffer_region->buffer->data)) { + result.push_back(buffer_region); + } + } + return result; + } + + /*! \brief The map from stmt to the buffers to be allocated under it. */ + std::unordered_map> alloc_buffers_; + /*! \brief Stack of buffers per data var for scoping correctness. */ + ffi::Map> buffer_data_to_buffers_; + /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ + std::unordered_set managed_allocations_; +}; + +PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + BufferAllocationLocator locator(func); + fptr->body = locator(fptr->body); + return func; +} + +namespace transform { + +Pass PlanAndUpdateBufferAllocationLocation() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return ::tvm::tl::PlanAndUpdateBufferAllocationLocation(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, + "tl.PlanAndUpdateBufferAllocationLocation", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.PlanAndUpdateBufferAllocationLocation", + PlanAndUpdateBufferAllocationLocation); +} + +} // namespace transform + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/simplify.cc b/tilelang/original/src/transform/simplify.cc new file mode 100644 index 0000000000000000000000000000000000000000..c10d5687a7aecf3ce698a53d80cf6fd21e314e1f --- /dev/null +++ b/tilelang/original/src/transform/simplify.cc @@ -0,0 +1,546 @@ +/*! + * \file simplify.cc + * \brief Statement simplifier based on analyzer and remove useless parameters + * of TL PrimFunc. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "arith/ir_mutator_with_analyzer.h" +#include "tir/analysis/control_flow_graph.h" +#include "tir/analysis/var_use_def_analysis.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace ffi; +using namespace arith; + +struct SimplifyConfigNode : public AttrsNodeReflAdapter { + bool transitively_prove_inequalities{}; + bool propagate_knowns_to_prove_conditional{}; + bool propagate_knowns_to_simplify_expressions{}; + bool convert_boolean_to_and_of_ors{}; + bool apply_constraints_to_boolean_branches{}; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("transitively_prove_inequalities", + &SimplifyConfigNode::transitively_prove_inequalities, + "If true, simplify conditionals with transitive combinations " + "of scoped constraints", + refl::DefaultValue(false)) + .def_ro("propagate_knowns_to_prove_conditional", + &SimplifyConfigNode::propagate_knowns_to_prove_conditional, + "If true, known buffer values are propagated and used to " + "statically prove conditionals", + refl::DefaultValue(false)) + .def_ro("propagate_knowns_to_simplify_expressions", + &SimplifyConfigNode::propagate_knowns_to_simplify_expressions, + "If true, known buffer values are propagated and used to " + "replace BufferLoad wherever " + "possible", + refl::DefaultValue(false)) + .def_ro("convert_boolean_to_and_of_ors", + &SimplifyConfigNode::convert_boolean_to_and_of_ors, + "If true, simplify conditionals into an AND of ORs", + refl::DefaultValue(false)) + .def_ro("apply_constraints_to_boolean_branches", + &SimplifyConfigNode::apply_constraints_to_boolean_branches, + "If true, simplify each branch of AND/OR under a constraints " + "provided by the other " + "branch", + refl::DefaultValue(false)); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig", + SimplifyConfigNode, BaseAttrsNode); + + RewriteSimplifier::Extension GetEnabledExtensions() const { + RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; + if (transitively_prove_inequalities) { + flags = RewriteSimplifier::Extension( + flags | RewriteSimplifier::kTransitivelyProveInequalities); + } + if (convert_boolean_to_and_of_ors) { + flags = RewriteSimplifier::Extension( + flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); + } + if (apply_constraints_to_boolean_branches) { + flags = RewriteSimplifier::Extension( + flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches); + } + return flags; + } +}; + +std::unordered_set +CollectUsedBuffers(const PrimFunc &func) { + struct Visitor : StmtExprVisitor { + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; + + Visitor(PrimFunc func) : func(std::move(func)) {} + + void VisitExpr_(const CallNode *op) override { + for (const auto &arg : op->args) { + for (const auto &it : func->buffer_map) { + if (Downcast(it.second.get()->data).same_as(arg)) { + used_in_buffer_def_.insert(it.second.get()); + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + void VisitExpr_(const BufferLoadNode *op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode *op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const BlockNode *op) override { + for (const auto &buffer : op->alloc_buffers) { + for (const auto &it : func->buffer_map) { + if (it.second.get()->data.same_as(buffer.get()->data)) { + 
used_in_buffer_def_.insert(it.second.get()); + } + } + } + for (const auto &buffer : op->reads) { + for (const auto &it : func->buffer_map) { + if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { + used_in_buffer_def_.insert(it.second.get()); + } + } + } + for (const auto &buffer : op->writes) { + for (const auto &it : func->buffer_map) { + if (it.second.get()->data.same_as(buffer->buffer.get()->data)) { + used_in_buffer_def_.insert(it.second.get()); + } + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const Buffer &buf) { + // Collect buffers that should remain defined + VarUseDefAnalyzer usage(Array{}); + usage(buf->data); + for (const auto &dim : buf->shape) { + usage(dim); + } + for (const auto &dim : buf->strides) { + usage(dim); + } + usage(buf->elem_offset); + + for (const auto &buffer : usage.buffer_use_count_) { + if (buffer.second >= 1) { + used_in_buffer_def_.insert(buffer.first); + } + } + for (const auto &buffer : usage.undefined_buffers_) { + used_in_buffer_def_.insert(buffer.get()); + } + } + PrimFunc func; + std::unordered_set used_in_buffer_def_; + }; + + Visitor visitor(func); + visitor(func->body); + return visitor.used_in_buffer_def_; +} + +/* \brief Utility function to collect vars that should be retained. Used in + * Letstmt Only + */ +std::unordered_set +CollectVarsUsedInBufferDefinition(const Stmt &stmt) { + struct Visitor : StmtExprVisitor { + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; + + void VisitExpr_(const BufferLoadNode *op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode *op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const Buffer &buf) { + // Collect variables that should remain defined + VarUseDefAnalyzer usage(Array{}); + usage(buf->data); + for (const auto &dim : buf->shape) { + usage(dim); + } + for (const auto &dim : buf->strides) { + usage(dim); + } + usage(buf->elem_offset); + + // Track for use in LetStmtNode mutator + for (const auto &var : usage.undefined_) { + used_in_buffer_def_.insert(var.get()); + } + } + std::unordered_set used_in_buffer_def_; + }; + + Visitor visitor; + visitor(stmt); + return visitor.used_in_buffer_def_; +} + +class SimplifyConfig : public Attrs { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, + SimplifyConfigNode); +}; +TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } + +TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); + +class StmtSimplifier : public IRMutatorWithAnalyzer { +public: + static PrimFunc + Apply(PrimFunc func, Analyzer *analyzer, + const Optional &config_opt = std::nullopt, + bool simplify_arguments = false) { + auto config = config_opt.value_or(AttrsWithDefaultValues()); + analyzer->rewrite_simplify.SetEnabledExtensions( + config->GetEnabledExtensions()); + + std::optional touch_pattern = std::nullopt; + if (config->propagate_knowns_to_prove_conditional || + config->propagate_knowns_to_simplify_expressions) { + touch_pattern = ControlFlowGraph(func->body); + } + + std::unordered_set used_in_buffer_def = + CollectVarsUsedInBufferDefinition(func->body); + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), + std::move(used_in_buffer_def)); + simplifier.MarkBufferMapShapes(func); + func.CopyOnWrite()->body = simplifier(func->body); + + // Optionally remove unused buffer parameters + if (simplify_arguments) { + // First get used buffers + 
simplifier.used_buffers_ = CollectUsedBuffers(func); + + bool param_updated = false; + Array new_params; + Map new_buffer_map; + // Check whether each buffer is used + for (const auto &var : func->params) { + if (func->buffer_map.find(var) != func->buffer_map.end()) { + if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != + simplifier.used_buffers_.end()) { + new_params.push_back(var); + new_buffer_map.Set(var, func->buffer_map[var]); + } else if (simplifier.used_in_buffer_def_.find( + func->buffer_map[var]->data.get()) != + simplifier.used_in_buffer_def_.end()) { + new_params.push_back(var); + new_buffer_map.Set(var, func->buffer_map[var]); + } else { + param_updated = true; + } + } else { + // Non-buffer parameters (e.g., scalars) are always retained + new_params.push_back(var); + } + } + + if (param_updated) { + return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, + new_buffer_map, func->attrs, func->span); + } + } + // Either no change to params or argument simplification disabled + return func; + } + +private: + explicit StmtSimplifier( + Analyzer *analyzer, SimplifyConfig config, + std::optional touch_pattern, + std::unordered_set used_in_buffer_def) + : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)), + touch_pattern_(std::move(touch_pattern)), + used_in_buffer_def_(std::move(used_in_buffer_def)) {} + + using Parent = IRMutatorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt; + using Parent::VisitStmt_; + + PrimExpr VisitExpr(const PrimExpr &expr) final { + if (config_->propagate_knowns_to_simplify_expressions) { + return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), + analyzer_); + } else { + return analyzer_->Simplify(expr); + } + } + + Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } + + Stmt VisitStmt(const Stmt &stmt) override { + Optional cache = this->current_stmt_; + this->current_stmt_ = stmt; + Stmt output = Parent::VisitStmt(stmt); + this->current_stmt_ = std::move(cache); + return output; + } + + Stmt VisitStmt_(const ForNode *op) final { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + With ctx1(analyzer_, op->loop_var >= op->min); + With ctx2(analyzer_, + op->loop_var < op->min + op->extent); + return Parent::VisitStmt_(op); + } + + bool CanInlineLetStmt(const LetStmtNode *op) { + if (is_const_number(op->value)) + return true; + if (op->value.as()) + return true; + // Won't face the deep expression explosion problem as in Let expression. + // attempt to inline as much as possible if the value integer type(can be + // index). + if (!op->value.dtype().is_int()) + return false; + return SideEffect(op->value) <= CallEffectKind::kPure; + } + + Stmt VisitStmt_(const LetStmtNode *op) override { + PrimExpr value = this->VisitExpr(op->value); + bool remove_buffer_alias = false; + // TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate + // fragment types. TVM currently reinterprets vectorized/shared accesses as + // Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later + // passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes + // that our layout can't handle. Force-inline (by dropping the let) whenever + // the alias spans more than 2 dims or carries vector lanes. 
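    // Editorial note (concrete reading of the check below, example values are
    // hypothetical): an alias whose bound value covers more than one non-unit
    // extent, e.g. a 128x32 shared sub-region, sets remove_buffer_alias and
    // the binding is dropped (the body is then checked to no longer use the
    // var); a scalar load or a single-element region instead falls through to
    // the ordinary CanInlineLetStmt handling.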
+ auto get_ranges = [&](const PrimExpr &expr) -> Array { + Array ranges; + if (const auto *load = expr.as()) { + for (const PrimExpr &index : load->indices) { + if (const auto *ramp = index.as()) { + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, Integer(1))); + } + } + } else if (const auto *region = expr.as()) { + for (const Range &range : region->region) { + ranges.push_back(range); + } + } + return ranges; + }; + Array ranges = get_ranges(value); + if (!ranges.empty()) { + int non_unit_dims = 0; + for (const Range &range : ranges) { + PrimExpr extent = analyzer_->Simplify(range->extent); + if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) { + continue; + } + ++non_unit_dims; + if (non_unit_dims > 1) { + remove_buffer_alias = true; + break; + } + } + } + if (remove_buffer_alias) { + Stmt body = this->VisitStmt(op->body); + bool used = UsesVar( + body, [&](const VarNode *var) { return var == op->var.get(); }); + ICHECK(!used) << "Let binding of BufferLoad is expected to be unused " + "before removal " + << op->var << " : " << op->value << " ."; + return body; + } + + bool can_inline = CanInlineLetStmt(op); + if (can_inline) { + analyzer_->Bind(op->var, value); + } else if (SideEffect(op->value) <= CallEffectKind::kPure) { + non_inlined_bindings_.Set(op->var, value); + } + Stmt body = this->VisitStmt(op->body); + + bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); + + if (can_inline && !used_in_buffer_def) { + return body; + } else if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); + } + } + + Stmt VisitStmt_(const IfThenElseNode *op) override { + if (Optional cond = ProveCondition(op->condition)) { + if (cond.value()->value) { + return this->VisitStmt(op->then_case); + } else if (op->else_case) { + return this->VisitStmt(op->else_case.value()); + } else { + return Evaluate(0); + } + } else { + return Parent::VisitStmt_(op); + } + } + + PrimExpr VisitExpr_(const CallNode *op) override { + if (op->op.same_as(builtin::if_then_else())) { + if (Optional cond = ProveCondition(op->args[0])) { + if (cond.value()->value) { + return this->VisitExpr(op->args[1]); + } else { + return this->VisitExpr(op->args[2]); + } + } + } + return Parent::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const VarNode *op) override { + used_vars_.insert(op); + return Parent::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) override { + auto buffer = op->buffer.get(); + if (used_buffers_.find(buffer) == used_buffers_.end()) { + used_buffers_.insert(buffer); + } + return Parent::VisitExpr_(op); + } + + // eliminate useless stores + Stmt VisitStmt_(const BufferStoreNode *op) override { + BufferStore store = Downcast(Parent::VisitStmt_(op)); + if (const BufferLoadNode *load = store->value.as()) { + if (load->buffer->data.same_as(store->buffer->data) && + ArrayDeepEqual(load->indices, store->indices) && + tir::ExprDeepEqual()(load->buffer->elem_offset, + store->buffer->elem_offset) && + ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && + ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { + return Evaluate(0); + } + } + auto buffer = op->buffer.get(); + if (used_buffers_.find(buffer) == used_buffers_.end()) { + used_buffers_.insert(buffer); + } + return std::move(store); + } + + Stmt VisitStmt_(const 
AttrStmtNode *op) override { + if (op->attr_key == "tl.assume") { + PrimExpr condition = this->VisitExpr(Downcast(op->node)); + auto n = CopyOnWrite(op); + n->node = std::move(condition); + return Parent::VisitStmt_(n.get()); + } + return Parent::VisitStmt_(op); + } + +private: + bool ArrayDeepEqual(const Array &lhs, const Array &rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { + return false; + } + } + return true; + } + + /* \brief Internal utility for checking conditionals + * + * Uses more aggressive optimization, such as performing additional + * inlining and tracking known buffer values. + */ + Optional ProveCondition(PrimExpr condition) const { + condition = Substitute(condition, non_inlined_bindings_); + if (config_->propagate_knowns_to_prove_conditional) { + ICHECK(touch_pattern_.has_value()); + condition = touch_pattern_->SimplifyInContext( + condition, current_stmt_.value(), analyzer_); + } else { + condition = analyzer_->Simplify(condition); + } + if (const int64_t *as_int = as_const_int(condition)) { + return Bool(*as_int); + } else { + // May have symbolic, need kSymbolicBound level prover. + if (analyzer_->CanProve(condition) || + analyzer_->CanProve(condition, + arith::ProofStrength::kSymbolicBound)) { + return Bool(true); + } + return std::nullopt; + } + } + + SimplifyConfig config_; + std::optional touch_pattern_; + + Map non_inlined_bindings_; + Optional current_stmt_{std::nullopt}; + std::unordered_set used_in_buffer_def_; + std::unordered_set used_vars_; + std::unordered_set used_buffers_; +}; + +using namespace tir::transform; + +tvm::transform::Pass Simplify(bool simplify_arguments = true) { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + arith::Analyzer analyzer; + auto cfg = ctx->GetConfig("tl.Simplify"); + return StmtSimplifier::Apply(std::move(f), &analyzer, cfg, + simplify_arguments); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.Simplify", Simplify); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/split_host_device.cc b/tilelang/original/src/transform/split_host_device.cc new file mode 100644 index 0000000000000000000000000000000000000000..bfdcb5cd55cd6e2cb666c8d2a0ccc5b61974d0d7 --- /dev/null +++ b/tilelang/original/src/transform/split_host_device.cc @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file split_host_device.cc + * \brief Split device function from host. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "common/assume.h" +#include "tir/analysis/var_use_def_analysis.h" +#include "tvm/node/cast.h" +#include "tvm/runtime/logging.h" +#include "tvm/tir/stmt.h" + +namespace tvm { +namespace tl { +using namespace ffi; +namespace tir = tvm::tir; + +// This pass traverses the AST, split the target function into host part and +// device part and copies all assume attribute statements to the device side. + +// 1. Traverse AST and collect all assume statements into host_assumes_. +// 2. Until the first AttrStmtNode with tvm::attr::kTarget. +// 3. Call SplitDeviceFunc, which will create a new device function and replace +// the original body with a call to that function. +class HostDeviceSplitter : public tir::StmtMutator { +public: + explicit HostDeviceSplitter(IRModule *device_mod, + std::function var_supply) + : device_mod_(device_mod), var_supply_(std::move(var_supply)) {} + + void SetNonRestrictParams(Optional> params) { + for (auto param : params.value()) { + non_restrict_params_.push_back(param); + } + } + + tir::Stmt VisitStmt_(const tir::AttrStmtNode *op) final { + if (op->attr_key == tvm::attr::kTarget) { + found_device_region_ = true; + auto device_target = op->node.as().value().WithoutHost(); + return SplitDeviceFunc(op->body, device_target); + } else if (op->attr_key == tir::attr::tilelang_assume) { + // NOTE(chaofan): the assumes collected here must be in host-side. + // This is because when the collector reaches the split region, + // it will start to split and return. For safety, we add a check here. + ICHECK(!found_device_region_) + << "Assumes collection should not be in device region."; + // We first push back the outside assume, then visit the child. + // So when moving assumes to device side, we need to do the building + // process in a reverse order. + host_assumes_.push_back(op); + } + return tir::StmtMutator::VisitStmt_(op); + } + + tir::Stmt VisitStmt_(const tir::EvaluateNode *op) final { + auto stmt = GetRef(op); + // There should be no assume in evaluate form after InjectAssumes. + ICHECK(!IsAssumeInEvaluateForm(stmt)) + << "Unexpected assume in evaluate form. 
Please run InjectAssumes pass " + "first."; + return tir::StmtMutator::VisitStmt_(op); + } + + tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) { + return SplitDeviceFunc(std::move(body), std::move(device_target)); + } + + bool found_device_region() const { return found_device_region_; } + +private: + bool found_device_region_{false}; + Array non_restrict_params_; + + Stmt wrapBodyWithHostSideAssumes(Stmt body) { + for (auto it = host_assumes_.rbegin(); it != host_assumes_.rend(); ++it) { + body = + AttrStmt((*it)->node, tir::attr::tilelang_assume, (*it)->value, body); + } + return body; + } + + tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) { + + auto [params, buffers_to_declare] = + [&]() -> std::tuple, Array> { + tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, + /*visit_thread_extent=*/true); + use_def(body); + + // Sort first by variable type, then by variable name + std::vector params{use_def.undefined_.begin(), + use_def.undefined_.end()}; + std::sort(params.begin(), params.end(), + [](const tir::Var &a, const tir::Var &b) { + auto sort_key = [](const tir::Var &var) { + return std::tuple{ + !var->dtype.is_handle(), + var->name_hint, + }; + }; + return sort_key(a) < sort_key(b); + }); + return {params, use_def.undefined_buffers_}; + }(); + + // CodeGenCPU is used for some device-side targets, such as + // "ext_dev", and expects to be able to return a int32_t status + // code. + + bool can_propagate_errors = [&]() { + auto kind = device_target->GetTargetDeviceType(); + return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon; + }(); + IntImm success(DataType::Int(32), 0); + Type kernel_ret_type; + if (can_propagate_errors) { + kernel_ret_type = PrimType(DataType::Int(32)); + body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success))); + } else { + kernel_ret_type = VoidType(); + } + + // Declare necessary buffers for the device side. + for (tir::Buffer buf : buffers_to_declare) { + body = tir::DeclBuffer(buf, std::move(body)); + } + + // Copy assumes from host-side to device-side. 
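+ // Illustrative sketch of the reverse-order rebuild (names assumed, not
+ // part of the pass): with collected assumes {A0 (outer), A1 (inner)},
+ //   std::vector<std::string> collected = {"A0", "A1"};
+ //   std::string nested = "kernel_body";
+ //   for (auto it = collected.rbegin(); it != collected.rend(); ++it)
+ //     nested = *it + "(" + nested + ")";
+ //   // nested == "A0(A1(kernel_body))"
+ // i.e. iterating in reverse restores the original nesting order.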
+ body = wrapBodyWithHostSideAssumes(body); + + tir::PrimFunc device_func(params, body, kernel_ret_type); + device_func = + WithAttrs(std::move(device_func), + {{tvm::attr::kTarget, device_target}, + {tir::attr::kNoAlias, true}, + {tir::attr::kIsGlobalFunc, true}, + {tl::attr::kNonRestrictParams, non_restrict_params_}}); + + GlobalVar kernel_symbol_global = var_supply_(); + (*device_mod_)->Add(kernel_symbol_global, device_func); + Array args = + params.Map([](const tir::Var &var) -> PrimExpr { return var; }); + + if (can_propagate_errors) { + tir::Var kernel_error_code("kernel_error_code", success->dtype); + tir::Call kernel_call(success->dtype, kernel_symbol_global, args); + tir::AssertStmt assert_success( + kernel_error_code == success, + tir::StringImm("Error executing compute kernel"), tir::Evaluate(0)); + tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success); + + return let_check; + + } else { + return tir::Evaluate( + tir::Call(DataType::Void(), kernel_symbol_global, args)); + } + } + + // target ir module + IRModule *device_mod_; + // Generate new GlobalVar for the kernel + std::function var_supply_; + // Collect assumes in host side + Array host_assumes_; +}; + +tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod, + std::function var_supply) { + HostDeviceSplitter splitter(device_mod, std::move(var_supply)); + // Propagate non-restrict parameter list from host func to device kernels + if (auto opt = func->GetAttr>(tl::attr::kNonRestrictParams)) { + splitter.SetNonRestrictParams(opt.value()); + // Remove the attribute from host-side PrimFunc; it only matters for device + // codegen. + func = tvm::WithoutAttr(std::move(func), tl::attr::kNonRestrictParams); + } + + if (auto body = splitter(func->body); !body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } else if (!splitter.found_device_region()) { + if (auto target = func->GetAttr(tvm::attr::kTarget)) { + auto device_target = target.value().WithoutHost(); + if (device_target.defined() && + func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && + tir::is_no_op(func->body)) { + if (auto forced = splitter.ForceSplit(func->body, device_target); + !forced.same_as(func->body)) { + func.CopyOnWrite()->body = forced; + } + } + } + } + return func; +} + +namespace transform { + +tvm::transform::Pass SplitHostDevice() { + auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) { + tvm::GlobalVarSupply global_var_supply(mod); + + IRModule device_mod = IRModule(Map({})); + IRModule updates = IRModule(Map({})); + + for (const auto &[gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + tir::PrimFunc func = opt.value(); + + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto name_prefix = global_symbol.value_or(gvar->name_hint); + auto kernel_name = name_prefix + "_kernel"; + auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { + return global_var_supply->FreshGlobal(kernel_name, false); + }; + + func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod, + var_supply); + if (!func.same_as(base_func)) { + updates->Add(gvar, func); + } + } + } + mod->Update(updates); + mod->Update(device_mod); + return tir::transform::ConvertSSA()(mod); + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice", + {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice); +} + +} // namespace transform +} // namespace tl +} // 
namespace tvm \ No newline at end of file diff --git a/tilelang/original/src/transform/storage_access.cc b/tilelang/original/src/transform/storage_access.cc new file mode 100644 index 0000000000000000000000000000000000000000..49c839929cecdce6ffe9ae61e1ae4be414c33b70 --- /dev/null +++ b/tilelang/original/src/transform/storage_access.cc @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file storage_access.cc + */ +#include "storage_access.h" + +#include +#include +#include + +#include +#include + +#include "../op/builtin.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) { + Var buf = op->buffer->data; + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { + ICHECK(allow_append_) << tvm::ffi::GetRef(op) << " " + << scope.to_string(); + AccessEntry e; + e.threads = env_threads(); + e.thread_range = this->ComputeThreadRange(e.threads); + e.buffer = buf; + e.buffer_indices = op->indices; + e.dtype = op->dtype.element_of(); + for (const auto &index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } + e.type = kRead; + e.scope = scope; + curr_stmt_.access.emplace_back(std::move(e)); + } + // traverse child + IRVisitorWithAnalyzer::VisitExpr_(op); +} + +void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + + Var buf = op->buffer->data; + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { + AccessEntry e; + e.threads = env_threads(); + e.thread_range = this->ComputeThreadRange(e.threads); + e.buffer = buf; + e.buffer_indices = op->indices; + e.dtype = op->value.dtype().element_of(); + for (const auto &index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } + e.type = kWrite; + e.scope = scope; + curr_stmt_.access.emplace_back(std::move(e)); + } + // traverse child + IRVisitorWithAnalyzer::VisitStmt_(op); + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. 
+ curr_stmt_.access.clear(); + allow_append_ = false; +} + +void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + IRVisitorWithAnalyzer::VisitStmt_(op); + // push to the scope + if (!curr_stmt_.access.empty()) { + scope_.back().push_back(curr_stmt_); + curr_stmt_.access.clear(); + } + allow_append_ = false; +} + +void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + this->VisitExpr(op->value); + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. + curr_stmt_.access.clear(); + allow_append_ = false; + // traverse body block + this->VisitStmt(op->body); +} + +void TileLangStorageAccessVisitor::VisitStmt_(const BlockNode *op) { + auto block = Downcast(op); + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + IRVisitorWithAnalyzer::VisitStmt_(op); +} + +void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tvm::tir::attr::double_buffer_write) { + ICHECK(double_buffer_write_ == nullptr); + double_buffer_write_ = op->node.as(); + scope_.push_back(std::vector()); + IRVisitorWithAnalyzer::VisitStmt_(op); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + if (!s.access.empty()) { + for (AccessEntry &e : s.access) { + if (e.type == kWrite && e.buffer.get() == double_buffer_write_) { + e.double_buffer_write = true; + } + } + scope_.back().emplace_back(std::move(s)); + } + double_buffer_write_ = nullptr; + } else if (op->attr_key == tvm::tir::attr::coproc_scope) { + IterVar iv = Downcast(op->node); + env_threads_.push_back(iv); + IRVisitorWithAnalyzer::VisitStmt_(op); + env_threads_.pop_back(); + } else if (op->attr_key == tvm::tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + env_threads_.push_back(iv); + ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_.Bind( + iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + + if (!in_device_env_) { + in_device_env_ = true; + scope_.push_back(std::vector()); + IRVisitorWithAnalyzer::VisitStmt_(op); + // no need to take the result as the thread barrier automatically syncs. + Summarize(std::move(scope_.back()), nullptr); + in_device_env_ = false; + scope_.pop_back(); + } else { + IRVisitorWithAnalyzer::VisitStmt_(op); + } + env_threads_.pop_back(); + } else if (op->attr_key == tvm::tir::attr::hand_threaded) { + // skip this pass on blocks that were hand_threaded + // this avoids control flow and read/write conflicts + // between hand-threaded kernels and automatic threading + } else { + IRVisitorWithAnalyzer::VisitStmt_(op); + } +} + +void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) { + scope_.push_back(std::vector()); + IRVisitorWithAnalyzer::VisitStmt_(op); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), op); + scope_.pop_back(); + if (!s.access.empty()) { + // relax the touched set to contain all ranges in the loop. 
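+ // Relaxation example (illustrative): in
+ //   for (i, 0, 16) { A[i * 2] = ...; }
+ // the per-iteration touched set {i * 2} is evaluated under
+ // relax_map[i] = [0, 16), so the parent scope conservatively sees the
+ // whole interval [0, 30] as touched by the loop.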
+ std::unordered_map relax_map; + relax_map[op->loop_var.get()] = + arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); + for (AccessEntry &e : s.access) { + if (e.buffer.defined()) { + ICHECK(!e.touched.empty()); + Array new_touched; + for (const auto &touched : e.touched) { + new_touched.push_back(arith::EvalSet(touched, relax_map)); + } + e.touched = std::move(new_touched); + } + } + } + if (!s.access.empty()) { + scope_.back().emplace_back(std::move(s)); + } +} + +bool IsThreadInvariant(const PrimExpr &cond) { + if (auto call = cond.as()) { + if (auto opt_call_op = call->op.as()) { + const auto &call_op = opt_call_op.value(); + if (call_op.same_as(builtin::tvm_thread_invariant())) { + return true; + } + } + } + return false; +} + +/** + * @brief Visit an IfThenElse statement and collect storage access summaries for + * its branches. + * + * Visits the if-then-else node's condition and both branches to summarize + * buffer reads, writes, and synchronization events under the condition's + * constraints. If the condition is not thread-invariant, increments an internal + * condition counter for the duration of processing. + * + * Behavior and side effects: + * - Evaluates the condition expression (using ExtractRealCondition) and applies + * it as a constraint while summarizing the then-branch. + * - For the else-branch (when present), applies the negated, + * analyzer-simplified condition + * (analyzer_.rewrite_simplify(Not(real_condition))) as the constraint. + * - Accumulates summarized StmtEntry access information for the then/else + * branches and appends a combined StmtEntry for the IfThenElseNode into the + * current scope. + * - Temporarily toggles allow_append_ and clears curr_stmt_.access during + * condition evaluation and branch summarization. + * - Modifies internal state: scope_ (push/pop of temporary branch scopes), + * curr_stmt_.access, and condition_counter_ (incremented/decremented when the + * condition is not thread-invariant). + */ +void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) { + bool is_thread_invariant = IsThreadInvariant(op->condition); + if (!is_thread_invariant) { + ++condition_counter_; + } + + allow_append_ = true; + this->VisitExpr(op->condition); + PrimExpr real_condition = ExtractRealCondition(op->condition); + + // Preserve accesses collected from the condition expression so they + // participate in dependency analysis. Otherwise, a write to shared memory + // immediately followed by an if-condition reading that memory would not + // trigger a sync before the if-statement. + std::vector cond_access = std::move(curr_stmt_.access); + allow_append_ = false; + + scope_.push_back(std::vector()); + { + With constraint(&analyzer_, real_condition); + this->VisitStmt(op->then_case); + } + + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + // Merge the condition's access summary into the if-statement's access list + // so the planner can insert a sync before the if when necessary. 
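+ // Hazard example (illustrative, CUDA-level view):
+ //   smem[threadIdx.x] = value;   // shared-memory write
+ //   if (smem[0] > 0) { ... }     // the condition reads shared memory
+ // Dropping the condition's read here would let the sync planner believe
+ // there is no read between the write and the branch, and the barrier
+ // that must separate them could be omitted.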
+ if (!cond_access.empty()) { + s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end()); + } + if (op->else_case) { + scope_.push_back(std::vector()); + { + With constraint( + &analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); + this->VisitStmt(op->else_case.value()); + } + auto v = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + s.access.insert(s.access.end(), v.begin(), v.end()); + } + scope_.back().emplace_back(std::move(s)); + if (!is_thread_invariant) { + --condition_counter_; + } +} + +void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) { + bool is_thread_invariant = IsThreadInvariant(op->condition); + if (!is_thread_invariant) { + ++condition_counter_; + } + this->VisitExpr(op->condition); + scope_.push_back(std::vector()); + this->VisitStmt(op->body); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + scope_.back().emplace_back(std::move(s)); + if (!is_thread_invariant) { + --condition_counter_; + } +} + +void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { + // Mark async TMA load context so that tvm_access_ptr within the call + // can be tagged accordingly. + auto is_tma_load = [&]() { + if (auto opt = op->op.as()) { + const Op &call_op = opt.value(); + return call_op.same_as(tl::tma_load()) || + call_op.same_as(tl::tma_load_im2col()); + } + return false; + }(); + if (is_tma_load) { + tma_depth_++; + for (const auto &a : op->args) { + this->VisitExpr(a); + } + tma_depth_--; + return; + } + if (op->op.same_as(builtin::address_of())) { + ICHECK_EQ(op->args.size(), 1U); + if (auto load = op->args[0].as()) { + Buffer buffer = load->buffer; + DataType dtype = buffer->dtype; + const VarNode *buffer_var = buffer->data.as(); + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buffer_var), buffer); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); + Array buffer_ranges; + // from indices to buffer indices + ICHECK(buffer->shape.size() == load->indices.size()); + // Use buffer shape and indices to compute the buffer_ranges for each + // dimension. + for (size_t i = 0; i < buffer->shape.size(); ++i) { + PrimExpr min = load->indices[i]; + PrimExpr extent = make_const(buffer->shape[i].dtype(), 1); + buffer_ranges.push_back(Range::FromMinExtent(min, extent)); + } + if (Enabled(buffer_var, scope)) { + ICHECK(allow_append_); + AccessEntry e; + e.threads = env_threads(); + e.thread_range = this->ComputeThreadRange(e.threads); + e.dtype = dtype; + e.buffer = Downcast(buffer->data); + e.buffer_ranges = buffer_ranges; + for (const auto &index : load->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } + e.is_pointer_access = true; + e.type = kRead; + e.scope = scope; + curr_stmt_.access.emplace_back(e); + } + IRVisitorWithAnalyzer::VisitExpr_(load); + } else { + IRVisitorWithAnalyzer::VisitExpr_(op); + } + } else if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + const VarNode *buffer_var = op->args[1].as(); + PrimExpr offset = op->args[2]; + PrimExpr extent = op->args[3]; + const IntImmNode *flag = op->args[4].as(); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); + // The buffer scope. 
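+ // Delinearization example (illustrative, assuming a row-major buffer):
+ // for shape {4, 8}, the flat offset 13 maps to indices
+ // {13 / 8, 13 % 8} = {1, 5}; the offset/extent pair of tvm_access_ptr
+ // is converted into per-dimension ranges this way so the pointer access
+ // can be compared against ordinary multi-dimensional accesses.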
+ if (Enabled(buffer_var, scope)) { + ICHECK(allow_append_); + Array buffer_ranges; + if (buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)) == + buffer_data_to_buffer_.end()) { + // cannot find buffer map, use the default buffer + buffer_ranges = {Range::FromMinExtent(offset, extent)}; + } else { + Buffer buffer = + buffer_data_to_buffer_.at(tvm::ffi::GetRef(buffer_var)); + auto buffer_shape = buffer->shape; + // convert 1d offset to multi-dimensional index + auto linear_to_indices = [this](PrimExpr offset, + const Array &shape) { + Array indices; + PrimExpr remaining = std::move(offset); + for (size_t i = 0; i < shape.size(); ++i) { + PrimExpr stride = make_const(DataType::Int(32), 1); + for (size_t j = i + 1; j < shape.size(); ++j) { + stride = stride * shape[j]; + } + PrimExpr idx = FloorDiv(remaining, stride); + remaining = FloorMod(remaining, stride); + indices.push_back(analyzer_.Simplify(idx)); + } + return indices; + }; + Array start_indices = linear_to_indices(offset, buffer_shape); + Array end_indices = + linear_to_indices(offset + extent, buffer_shape); + for (size_t i = 0; i < buffer_shape.size(); ++i) { + buffer_ranges.push_back(Range::FromMinExtent( + start_indices[i], + analyzer_.Simplify(end_indices[i] - start_indices[i]))); + } + } + AccessEntry e; + e.threads = env_threads(); + e.thread_range = this->ComputeThreadRange(e.threads); + e.dtype = dtype; + e.buffer = tvm::ffi::GetRef(buffer_var); + e.buffer_ranges = buffer_ranges; + e.is_pointer_access = true; + e.touched = { + arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; + e.scope = scope; + if (flag->value & 1) { + e.type = kRead; + e.is_async_copy = (tma_depth_ > 0); + curr_stmt_.access.emplace_back(e); + } + if (flag->value & 2) { + e.type = kWrite; + e.is_async_copy = (tma_depth_ > 0); + curr_stmt_.access.emplace_back(e); + } + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } else if (op->op.same_as(builtin::tvm_storage_sync())) { + ICHECK(allow_append_); + const std::string &s = op->args[0].as()->value; + if (s != "warp") { + StorageScope scope = StorageScope::Create(s); + AccessEntry e; + e.threads = env_threads(); + e.thread_range = this->ComputeThreadRange(e.threads); + e.type = kSync; + e.scope = StorageScope::Create(s); + curr_stmt_.access.emplace_back(std::move(e)); + } + } else { + IRVisitorWithAnalyzer::VisitExpr_(op); + } +} + +Map TileLangStorageAccessVisitor::ComputeThreadRange( + const Array &threads) { + Map thread_range; + for (const auto &th : threads) { + auto thread_tag = th->thread_tag; + if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || + thread_tag == "threadIdx.z") { + auto const_int_bound = analyzer_.const_int_bound(th->var); + auto min_value = const_int_bound->min_value; + auto max_value = const_int_bound->max_value; + auto extent = max_value - min_value + 1; + auto dtype = th->var.dtype(); + thread_range.Set(th->var, Range::FromMinExtent(IntImm(dtype, min_value), + IntImm(dtype, extent))); + } + } + return thread_range; +} + +StorageScope +TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const { + if (buffer_var->type_annotation.as()) { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); + } + return StorageScope(); // global by default +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/storage_access.h b/tilelang/original/src/transform/storage_access.h new file mode 100644 index 0000000000000000000000000000000000000000..54114ace24d4568dcb9e8db666cba78da6f3d919 --- /dev/null +++ 
b/tilelang/original/src/transform/storage_access.h @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file storage_access.h + * \brief Common data structure for storage access analysis. + */ +#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ +#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ + +#include +#include +#include +#include + +#include +#include + +#include "arith/ir_visitor_with_analyzer.h" +#include "runtime/thread_storage_scope.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace ffi; +using arith::IRVisitorWithAnalyzer; +using runtime::StorageRank; +using runtime::StorageScope; + +/*! + * \brief Base class of storage access analysis + */ +class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer { +public: + /*! \brief Storage access type */ + enum AccessType : uint8_t { + kRead, + kWrite, + kSync, + kAlloc, + // acquired version of read, only need to handle WAR dep. + kReadAcquire + }; + /*! \brief An access entry */ + struct AccessEntry { + /*! \brief The thread index that access this entry */ + Array threads; + /*! \brief The touched thread range */ + Map thread_range; + /*! \brief The buffer variable, if any */ + Array buffer_indices; + /*! \brief The buffer ranges for pointer access */ + Array buffer_ranges; + Var buffer = NullValue(); + /*! \brief The access data type */ + DataType dtype; + /*! \brief The touched access range + * + * Has one IntSet for each index in the buffer being accessed. + */ + Array touched; + /*! \brief The type of access */ + AccessType type; + /*! \brief The storage scope */ + StorageScope scope; + /*! \brief Whether the access is double buffer write */ + bool double_buffer_write = false; + /*! \brief Whether the access is pointer access */ + bool is_pointer_access = false; + /*! \brief Whether this access originates from an async copy context + * (e.g., inside a TMA load) and therefore multiple writes + * among themselves should not force barriers between them. */ + bool is_async_copy = false; + }; + + /*! \brief Access pattern about a single statement */ + struct StmtEntry { + /*! \brief The statement */ + const Object *stmt{}; + /*! 
\brief access patterns in the statement */ + std::vector access; + }; + // override visitor pattern + void VisitExpr_(const BufferLoadNode *op) final; + void VisitStmt_(const BufferStoreNode *op) final; + void VisitStmt_(const EvaluateNode *op) final; + void VisitStmt_(const LetStmtNode *op) final; + void VisitStmt_(const AttrStmtNode *op) override; + void VisitStmt_(const ForNode *op) final; + void VisitStmt_(const IfThenElseNode *op) final; + void VisitStmt_(const WhileNode *op) final; + void VisitExpr_(const CallNode *op) final; + void VisitStmt_(const BlockNode *op) final; + + void SetBufferDataToBuffer(const Var &buffer_var, const Buffer &buffer) { + buffer_data_to_buffer_.Set(buffer_var, buffer); + } + +protected: + TileLangStorageAccessVisitor() { scope_.push_back(std::vector()); } + /*! \return number of conditions in the current scope. */ + int condition_counter() const { return condition_counter_; } + /*! \return whether we are in device environment. */ + bool in_device_env() const { return in_device_env_; } + /*! \return environment threads */ + const Array &env_threads() const { return env_threads_; } + /*! + * \brief Whether we need analyze the buffer in current scope. + * \param buffer The buffer to be checked + * \param scope The scope of the buffer. + * \return Whether the analysis of buffer is enabled. + */ + virtual bool Enabled(const VarNode *buffer, const StorageScope &scope) const { + return true; + } + /*! + * \brief Summarize the sequence of operations into parent. + * + * Insert synchronization if necessary and remove un-necessary + * memory access which are already synced. + * + * \param seq The sequence of the access operations. + * \param loop Pass loop node if it is a loop, otherwise nullptr. + * \return The summarized sequence that represent access that + * the parent should taken care of to synchronize. + */ + virtual std::vector Summarize(std::vector seq, + const ForNode *loop) = 0; + + /*! + * \brief Compute the thread range for the given threads. + * \param threads The threads to compute the range for. + * \return The thread range. + */ + Map ComputeThreadRange(const Array &threads); + + /*! + * \brief Get the scope of the buffer array. + * \return The scope of the final buffer array. + */ + StorageScope GetScope(const Var &buffer_var) const; + // access scope + std::vector> scope_; + +private: + // whether access appending is enabled. + bool allow_append_{false}; + // Whether we are in device environment + bool in_device_env_{false}; + // Nesting depth of tma_load/tma_load_im2col calls + int tma_depth_{0}; + // Whether we are inside condition. + int condition_counter_{0}; + // The current double buffer write scope. + const VarNode *double_buffer_write_{nullptr}; + // the current free stmt entry. + StmtEntry curr_stmt_; + // The involving threads + Array env_threads_; + // The buffer map + Map buffer_data_to_buffer_; +}; +} // namespace tl +} // namespace tvm +#endif // TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_ diff --git a/tilelang/original/src/transform/storage_rewrite.cc b/tilelang/original/src/transform/storage_rewrite.cc new file mode 100644 index 0000000000000000000000000000000000000000..40973f39abffe58105766a47887271e1934f3aae --- /dev/null +++ b/tilelang/original/src/transform/storage_rewrite.cc @@ -0,0 +1,2029 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file storage_rewrite.cc + * \brief Memory access pattern analysis and optimization. + * Re-write data access to enable memory sharing when possible. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "arith/int_operator.h" +#include "runtime/thread_storage_scope.h" +#include "tir/ir/buffer_common.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using runtime::StorageRank; +using runtime::StorageScope; +using namespace tir; + +/*! + * \brief Perform data type legalization on the given BufferLoadNode pointer. + * Equal to BufferLoadNode::LegalizeDType, but operates on a pointer. + * \param n A pointer to a writable BufferLoadNode. + */ +static void LegalizeBufferLoadDType(BufferLoadNode *n) { + // Check that all indices except the last one have a scalar dtype + for (int i = 0; i < static_cast(n->indices.size()) - 1; i++) { + ICHECK(n->indices[i].dtype().is_scalar()) + << "Only the last index of a buffer access may be a vector type."; + } + + // If there are no indices, set the dtype to the buffer's dtype + if (n->indices.empty()) { + n->dtype = n->buffer->dtype; + } else { + auto index_dtype = n->indices.back().dtype(); + bool is_buffer_dtype_scalable = n->buffer->dtype.is_scalable_vector(); + bool is_index_scalable = index_dtype.is_scalable_vector(); + + // Do not allow both index dtype and buffer dtype to be scalable vectors + ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) + << "Index dtype and buffer dtype cannot both be scalable."; + + if (is_index_scalable) { + // Index is a scalable vector, while the buffer is not + n->dtype = n->buffer->dtype.with_scalable_vscale_factor( + index_dtype.vscale_factor() * n->buffer->dtype.lanes()); + } else if (is_buffer_dtype_scalable) { + // The buffer is a scalable vector, while the index is not + n->dtype = n->buffer->dtype.with_scalable_vscale_factor( + n->buffer->dtype.vscale_factor() * index_dtype.lanes()); + } else { + // Neither side is a scalable vector, multiply lanes + n->dtype = n->buffer->dtype.with_lanes(index_dtype.lanes() * + n->buffer->dtype.lanes()); + } + } +} + +/*! 
+ * \brief collect the mapping from the buffer var to its allocate + */ +class AllocateCollector : public StmtExprVisitor { +private: + bool IsDynamicSharedMemory(Var buffer_var) { + StorageScope storage_scope = runtime::StorageScope::Create( + GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag == ".dyn"; + } + + bool IsStaticSharedMemory(Var buffer_var) { + StorageScope storage_scope = runtime::StorageScope::Create( + GetPtrStorageScope(std::move(buffer_var))); + return storage_scope.rank == runtime::StorageRank::kShared && + storage_scope.tag.empty(); + } + +public: + void VisitStmt_(const AllocateNode *op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_[op->buffer_var.get()] = op; + } else if (IsStaticSharedMemory(op->buffer_var)) { + static_shmem_allocs_[op->buffer_var.get()] = op; + } + StmtExprVisitor::VisitStmt_(op); + } + // The dynamic mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The static mapping from the original buffer var to its allocate + std::unordered_map + static_shmem_allocs_; +}; + +// Find a linear pattern of storage access +// Used for liveness analysis. +// Composite scopes(loop/thread_launch/IfThen) is represented by two points: +// before_scope -> scope_body -> after_scope +// +// The linear_seq_ stores before_scope and after_scope. +// The access to the arrays are stored at the after_scope point. +// +// Define "scope" as the body of For/thread_launch/IfThenElse +// This pass tries to detect last point that we need to keep memory +// alive under the same scope as allocate. +// The storage need to be kept alive between allocate and last access. +// The free point is only inserted at the same scope of allocate. +// +class LinearAccessPatternFinder final : public StmtExprVisitor { +public: + /*! \brief record the touch hist of statement. */ + struct StmtEntry { + // The statement + const Object *stmt{}; + // The index in the linear_seq_ to point to end of the nested scope. + // This is only set to non-zero if stmt is a nested scope. + // if offset > 0, means this is the begin, the end entry is current_index + + // offset if offset < 0, means this is the end, the begin entry is + // current_index + offset + int64_t scope_pair_offset{0}; + // The buffer variables this statement touched. + std::vector touched; + }; + // The scope of each allocation + struct AllocEntry { + // The physical dimension of the allocation. + size_t num_physical_dimensions{0}; + // scope level + size_t level{0}; + // allocation stmt + const AllocateNode *alloc{nullptr}; + }; + + void VisitStmt_(const AllocateNode *op) final { + size_t level = scope_.size(); + const VarNode *buf = op->buffer_var.get(); + + AllocEntry entry; + entry.alloc = op; + entry.level = level; + // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer, + // all allocations specify the extent of physical dimensions, and + // is 1 for flat memory spaces. + entry.num_physical_dimensions = op->extents.size(); + alloc_info_[buf] = entry; + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + all_buffers_accessed_.insert(op->buffer.get()); + + // Add write access. 
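+ // Dimension-check example (illustrative): an Allocate with extents
+ // {64, 128} has two physical dimensions, so a Buffer accessing it must
+ // carry exactly one axis separator (axis_separators.size() + 1 == 2);
+ // a flat buffer with no separators would trip the ICHECK below.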
+ const VarNode *buffer_var = op->buffer->data.get(); + auto it = alloc_info_.find(buffer_var); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + scope_[it->second.level].touched.push_back(buffer_var); + + ICHECK_EQ(op->buffer->axis_separators.size() + 1, + it->second.num_physical_dimensions) + << "Buffer " << op->buffer->name << " is allocated with " + << it->second.num_physical_dimensions + << " physical dimensions, but is accessed as having " + << op->buffer->axis_separators.size() + 1 << " physical dimensions" + << '\n'; + } + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (!e.touched.empty()) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + + all_buffers_accessed_.insert(op->buffer.get()); + + const VarNode *buffer_var = op->buffer->data.get(); + auto it = alloc_info_.find(buffer_var); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) + << "Load memory in places other than store."; + scope_[it->second.level].touched.push_back(buffer_var); + + ICHECK_EQ(op->buffer->axis_separators.size() + 1, + it->second.num_physical_dimensions) + << "Buffer " << op->buffer->name << " is allocated with " + << it->second.num_physical_dimensions + << " physical dimensions, but is accessed as having " + << op->buffer->axis_separators.size() + 1 << " physical dimensions" + << '\n'; + } + } + + void VisitStmt_(const EvaluateNode *op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (!e.touched.empty()) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + + void VisitExpr_(const VarNode *buf) final { + // Directly reference to the variable count as a read. + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; + scope_[it->second.level].touched.push_back(buf); + } + } + + template void VisitNewScope(const T *op) { + scope_.push_back(StmtEntry()); + StmtEntry e; + e.stmt = op; + int64_t begin_index = static_cast(linear_seq_.size()); + // before scope. + linear_seq_.push_back(e); + StmtExprVisitor::VisitStmt_(op); + // after scope. + e.touched = std::move(scope_.back().touched); + scope_.pop_back(); + int64_t end_index = static_cast(linear_seq_.size()); + ICHECK_GT(end_index, begin_index); + e.scope_pair_offset = begin_index - end_index; + linear_seq_.push_back(e); + // record the pointer to end index. + ICHECK_NE(end_index, 0U); + linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; + } + + void VisitStmt_(const AttrStmtNode *op) final { + // Only record the outer most thread extent. 
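+ // For example (illustrative): a kernel launch typically nests
+ //   AttrStmt(thread_extent, blockIdx.x)     <- outermost: new scope here
+ //     AttrStmt(thread_extent, threadIdx.x)  <- inner: traversed, no scope
+ // so the entire launch body becomes a single composite scope in the
+ // linearized access sequence.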
+ if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + VisitNewScope(op); + in_thread_env_ = false; + } else if (op->attr_key == tir::attr::extern_scope) { + VisitNewScope(op); + } else if (op->attr_key == tir::attr::virtual_thread) { + VisitNewScope(op); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + + void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const ForNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const AssertStmtNode *op) final { VisitNewScope(op); } + + void VisitStmt_(const LetStmtNode *op) final { VisitNewScope(op); } + + // linearized access sequence. + std::vector linear_seq_; + // The storage scope of each buffer + std::unordered_map alloc_info_; + // A record of which Buffer objects have been accessed, to prune + // unused DeclBuffer instances. + std::unordered_set all_buffers_accessed_; + +private: + // Whether already in thread env. + bool in_thread_env_{false}; + // The scope stack. + std::vector scope_; +}; + +// Verify if the statement can be run safely via inplace fashion +// +// Detect pattern: dst[index] = f(src[index]) +// +// WARNING: the current detection algorithm cannot handle the case +// when a location in an array is written multiple times +// +// For example, the following program will pass the check, +// but we cannot make A and B to be the same array. +// +// A[0] = B[0] + 1 +// A[0] = B[0] + 1 +// +// The high level code generator needs to ensure that the generated +// code only write each location of the target array once. +// +// This is the case with IR generated by the current compute schedule. +// We explicitly return false if we find there is an extern block +// which can be arbitrary IR. +// +// Neve-the-less, inplace detector should be used with care in mind. +// We may also consider introduce a condition checker that checks +// if every index only visited once for an absolute sufficient condition. +// +// The code after inplace transformation is no longer idempotent. 
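+//
+// Some illustrative cases (dst = B, src = A):
+//   B[i] = A[i] * 2      // accepted: identical load/store indices
+//   B[i] = A[i + 1]      // rejected: indices differ
+//   B[i] = C[A[i]]       // rejected: indirect (nested) load
+//   B[i] = A[i] + B[i]   // rejected: dst is also read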
+// +class InplaceOpVerifier : public StmtExprVisitor { +public: + bool Check(const Object *stmt, const VarNode *dst, const VarNode *src) { + dst_ = dst; + src_ = src; + result_ = true; + if (stmt->IsInstance()) { + VisitStmt_(reinterpret_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(reinterpret_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(reinterpret_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(reinterpret_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(reinterpret_cast(stmt)); + } else { + return false; + } + return result_; + } + + using StmtExprVisitor::VisitStmt_; + + void VisitStmt(const Stmt &n) final { + if (!result_) + return; + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr(const PrimExpr &n) final { + if (!result_) + return; + StmtExprVisitor::VisitExpr(n); + } + + void VisitExpr_(const VarNode *op) final { + // assume all opaque access is unsafe + if (op == dst_ || op == src_) { + result_ = false; + return; + } + } + + void VisitStmt_(const BufferStoreNode *op) final { + ++mem_nest_; + for (const auto &index : op->indices) { + this->VisitExpr(index); + } + --mem_nest_; + if (op->buffer->data.get() == dst_) { + store_ = op; + this->VisitExpr(op->value); + store_ = nullptr; + } else { + this->VisitExpr(op->value); + } + } + + void VisitStmt_(const AttrStmtNode *op) final { + // always reject extern code + if (op->attr_key == tir::attr::extern_scope || + op->attr_key == tir::attr::volatile_scope) { + result_ = false; + return; + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode *op) final { + const VarNode *buf = op->buffer->data.get(); + // cannot read from dst_ (no reduction) + if (buf == dst_) { + result_ = false; + return; + } + // do not allow indirect memory load + if (mem_nest_ != 0) { + result_ = false; + return; + } + if (src_ == buf) { + if (store_ == nullptr || store_->value.dtype() != op->dtype) { + result_ = false; + return; + } + ICHECK_EQ(store_->indices.size(), op->indices.size()) + << "Store/Load occur to the same buffer " << buf->name_hint + << " with differing number of indices"; + for (size_t i = 0; i < store_->indices.size(); i++) { + if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) { + result_ = false; + return; + } + } + } + ++mem_nest_; + StmtExprVisitor::VisitExpr_(op); + --mem_nest_; + } + +private: + // result of the check + bool result_{true}; + // destination memory + const VarNode *dst_{}; + // source variable + const VarNode *src_{}; + // counter of load, + // it is not safe to inplace when there is nested load like A[B[i]] + int mem_nest_{0}; + // The current store to be inspected + const BufferStoreNode *store_{nullptr}; +}; + +/* \brief Rewrite and merge memory allocation. + * + * Using LinearAccessPatternFinder, determines which buffers could share an + * allocation. This includes both sequential usage of the same buffer and + * merging small allocations at the same scope into a single larger allocation. + * The merging of small allocations requires the codegen to cast the resulting + * value from the storage type to the output type after access. 
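+ *
+ * Illustrative example (sizes assumed, tagged-memory merge path): two
+ * allocations attached to the same scope,
+ *   float16 x[64]  -> 1024 bits, placed at bits_offset 0
+ *   float32 y[16]  ->  512 bits, placed at bits_offset 1024
+ * can share one backing array; an access y[i] is then remapped to
+ * backing[1024 / 32 + i] = backing[32 + i], and codegen casts between
+ * the backing element type and float32 when the two differ.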
+ */ +class StoragePlanRewriter : public StmtExprMutator { +public: + using StmtEntry = LinearAccessPatternFinder::StmtEntry; + using AllocEntry = LinearAccessPatternFinder::AllocEntry; + + Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse, + bool reuse_require_exact_matched_dtype, + Map local_var_init_map = {}) { + detect_inplace_ = detect_inplace; + local_var_init_map_ = std::move(local_var_init_map); + // plan the rewrite + LinearAccessPatternFinder finder; + finder(stmt); + this->LivenessAnalysis(finder.linear_seq_); + this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse, + reuse_require_exact_matched_dtype); + all_buffers_accessed_ = finder.all_buffers_accessed_; + this->PrepareNewAlloc(); + // start rewrite + stmt = operator()(std::move(stmt)); + if (attach_map_.count(nullptr)) { + return MakeAttach(attach_map_.at(nullptr), stmt); + } + return stmt; + } + + template Node VisitBufferAccess(Node node) { + auto it = alloc_map_.find(node->buffer->data.get()); + if (it != alloc_map_.end()) { + Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); + + Array indices = node->indices; + indices.Set(indices.size() - 1, + RemapIndex(node->buffer->dtype, indices[indices.size() - 1], + it->second)); + + auto writer = node.CopyOnWrite(); + writer->buffer = buf; + writer->indices = indices; + } + return node; + } + + Buffer RemapBuffer(const Buffer &buf, const Var &new_backing_array) { + auto key = buf.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + ICHECK_EQ(it->second->data.get(), new_backing_array.get()) + << "Cannot remap buffer " << buf->name << " to use backing array " + << new_backing_array->name_hint << ", previously used backing array " + << it->second->data->name_hint; + return it->second; + } + + Buffer remapped = Buffer( + new_backing_array, buf->dtype, buf->shape, buf->strides, + buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, + buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); + buffer_remap_[key] = remapped; + return remapped; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const VarNode *op) final { + auto it = alloc_map_.find(op); + if (it != alloc_map_.end()) { + if (it->second->bits_offset != 0) { + LOG(WARNING) + << "Use a merged buffer variable address, could cause error"; + } + return it->second->alloc_var; + } else { + return tvm::ffi::GetRef(op); + } + } + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + const VarNode *buffer = op->args[1].as(); + auto it = alloc_map_.find(buffer); + if (it == alloc_map_.end()) { + return StmtExprMutator::VisitExpr_(op); + } + const StorageEntry *se = it->second; + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + uint64_t elem_bits = dtype.bits() * dtype.lanes(); + ICHECK_EQ(se->bits_offset % elem_bits, 0U); + if (se->bits_offset != 0) { + offset = + make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; + } + return Call(op->dtype, op->op, + {op->args[0], se->alloc_var, offset, extent, op->args[4]}); + } else { + return 
StmtExprMutator::VisitExpr_(op); + } + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent || + op->attr_key == tir::attr::virtual_thread || + tir::attr::IsPragmaKey(op->attr_key)) { + // remake all the allocation at the attach scope. + if (attach_map_.count(op)) { + auto &svec = attach_map_[op]; + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + return AttrStmt(op->node, op->attr_key, op->value, + MakeAttach(svec, op->body)); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } else if (op->attr_key == tir::attr::volatile_scope) { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + auto it = alloc_map_.find(op->node.as()); + if (it == alloc_map_.end()) + return stmt; + return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const ForNode *op) final { + ICHECK(op->kind != ForKind::kVectorized) + << "VectorizeLoop before LiftStorageAlloc"; + // remake all the allocation at the attach scope. + if (attach_map_.count(op)) { + auto &svec = attach_map_[op]; + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + return For(op->loop_var, op->min, op->extent, op->kind, + MakeAttach(svec, op->body), op->thread_binding, + op->annotations); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const AllocateNode *op) final { + return this->VisitStmt(op->body); + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + if (hoisted_buffer_decls_.count(op->buffer.get()) || + !all_buffers_accessed_.count(op->buffer.get())) { + return this->VisitStmt(op->body); + } + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + if (auto it = alloc_map_.find(op->buffer->data.get()); + it != alloc_map_.end()) { + Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var); + node.CopyOnWrite()->buffer = buf; + } + return std::move(node); + } + +private: + struct StorageEntry { + // The scope that this alloc attaches after + // For shared/local memory it is beginning of the thread extent. + // for global memory it is nullptr, means beginning of everything. + const Object *attach_scope_{nullptr}; + // The constant size of the buffer in bits, only used if it is constant + uint64_t const_nbits{0}; + // The storage scope. + StorageScope scope; + // The physical dimensionality of the allocations. Since + // StorageRewrite is applied after StorageFlatten/FlattenBuffer, + // this is size of `AllocateNode::extents`. If moved + size_t ndim{}; + // Allocs that shares this entry. + std::vector allocs; + // The children of this entry, not including itself. + std::vector merged_children; + // The replacement Allocate, if any. May also include associated + // DeclBuffer statement. + std::vector alloc_nest; + // The var expr of new allocation. + Var alloc_var; + // The allocation element type. + DataType elem_type; + // This is non-zero if this allocate is folded into another one + // the address(in bits) becomes alloc_var + bits_offset; + // can be effectively converted to the element type. + // We need to convert bit_offset to offset of specific element type later. + // + // We use bits(instead of bytes) to support non-conventional indexing in + // hardware. When we are merging buffer together, the bits_offset are set to + // be aligned to certain value given by the max_simd_bits property of the + // special memory. 
+ // + // This allows effective sharing among different types as long as their + // alignment requirement fits into the max_simd_bits. + uint64_t bits_offset{0}; + }; + + // Checks whether the storage_scope is especially tagged for a specific + // memory. Special memory is all combined into a single allocation. + bool IsSpecialTaggedMemory(const StorageScope &scope) { + return !scope.tag.empty() && scope.tag != ".dyn" && + scope.tag != ".barrier" && scope.tag != ".workspace" && + scope.tag != ".vtcm" && scope.tag != ".var" && + scope.tag.find(".descriptor") != 0; + } + + // Allocate entry of node. + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + Stmt MakeAttach(const std::vector &svec, Stmt body) { + for (auto it = svec.rbegin(); it != svec.rend(); it++) { + body = MergeNest((*it)->alloc_nest, body); + } + return body; + } + Map MakeAllocateAnnotations(const Var &buffer_var) const { + Map annotations; + if (local_var_init_map_.defined()) { + auto it = local_var_init_map_.find(buffer_var); + if (it != local_var_init_map_.end()) { + const PrimExpr &init = (*it).second; + annotations.Set(tl::attr::kLocalVarInit, init); + } + } + return annotations; + } + // Remap the index + PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) { + if (e->bits_offset == 0) + return index; + uint64_t elem_bits = dtype.bits(); + ICHECK_EQ(e->bits_offset % elem_bits, 0U); + return make_const(index.dtype(), e->bits_offset / elem_bits) + index; + } + // Prepare the new allocations + void PrepareNewAlloc() { + for (size_t i = 0; i < alloc_vec_.size(); ++i) { + StorageEntry *e = alloc_vec_[i].get(); + attach_map_[e->attach_scope_].push_back(e); + } + // find allocation via attach map. + for (auto &kv : attach_map_) { + // find the element with the most amount of bytes. + std::vector &vec = kv.second; + // try to find merge, for tagged memory + for (size_t i = 0; i < vec.size(); ++i) { + StorageEntry *e = vec[i]; + if (IsSpecialTaggedMemory(e->scope)) { + ICHECK_NE(e->const_nbits, 0U) + << "Special tagged memory must be const size"; + for (size_t j = 0; j < i; ++j) { + if (e->scope == vec[j]->scope) { + vec[j]->merged_children.push_back(e); + break; + } + } + } + } + // Start allocation + for (size_t i = 0; i < vec.size(); ++i) { + StorageEntry *e = vec[i]; + // already merged + if (e->bits_offset != 0) + continue; + if (!e->merged_children.empty()) { + NewAllocTagMerged(e); + continue; + } + // Get the allocation size; + e->alloc_var = e->allocs[0]->buffer_var; + DataType alloc_type = e->allocs[0]->dtype; + for (const AllocateNode *op : e->allocs) { + if (op->dtype.lanes() > alloc_type.lanes()) { + alloc_type = op->dtype; + } + } + + bool all_allocs_identical = std::all_of( + e->allocs.begin() + 1, e->allocs.end(), + [&](const AllocateNode *op) -> bool { + const AllocateNode *first = *e->allocs.begin(); + if (op->dtype != first->dtype) { + return false; + } + if (op->extents.size() != first->extents.size()) { + return false; + } + ExprDeepEqual expr_equal; + for (size_t i = 0; i < op->extents.size(); i++) { + if (!expr_equal(op->extents[i], first->extents[i])) { + return false; + } + } + return true; + }); + + if (all_allocs_identical) { + // simply use the original allocation. 
+ Map annotations = + MakeAllocateAnnotations(e->alloc_var); + e->alloc_nest.push_back(Allocate( + e->alloc_var, alloc_type, e->allocs[0]->extents, + e->allocs[0]->condition, Evaluate(0), std::move(annotations))); + if (auto ptr = e->allocs[0]->body.as()) { + e->alloc_nest.push_back(DeclBuffer( + RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0))); + hoisted_buffer_decls_.insert(ptr->buffer.get()); + } + if (IsSpecialTaggedMemory(e->scope)) { + MemoryInfo info = GetMemoryInfo(e->scope.to_string()); + if (info.defined()) { + uint64_t total_elem = e->const_nbits / e->elem_type.bits(); + ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) + << "Allocation exceed bound of memory tag " + << e->scope.to_string(); + } + } + } else { + // Build a merged allocation + PrimExpr combo_size; + for (const AllocateNode *op : e->allocs) { + ICHECK_EQ(op->extents.size(), 1) + << "Buffer var " << op->buffer_var->name_hint + << " was identified as a reusable allocation, but has " + << op->extents.size() << " physical dimensions. " + << "Currently, only flat 1-d memory spaces should be " + "identified as reusable " + "allocations."; + PrimExpr sz = op->extents[0]; + auto nbits = op->dtype.bits() * op->dtype.lanes(); + if (const auto *imm = sz.as()) { + if (imm->value > std::numeric_limits::max() / nbits) { + LOG(WARNING) << "The allocation requires : " << imm->value + << " * " << nbits + << " bits, which is greater than the maximum of" + " int32. The size is cast to int64." + << "\n"; + sz = make_const(DataType::Int(64), imm->value); + } + } + // transform to bits + auto sz_nbits = sz * nbits; + if (combo_size.defined()) { + combo_size = max(combo_size, sz_nbits); + } else { + combo_size = sz_nbits; + } + } + // transform to alloc bytes + auto type_bits = alloc_type.bits() * alloc_type.lanes(); + bool divided = + analyzer_.CanProve(indexmod(combo_size, type_bits) == 0); + combo_size = indexdiv(combo_size, type_bits); + // round up for can not divided + if (!divided) { + combo_size = combo_size + make_const(DataType::Int(32), 1); + } + combo_size = analyzer_.Simplify(combo_size); + Map annotations = + MakeAllocateAnnotations(e->alloc_var); + e->alloc_nest.push_back( + Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), + Evaluate(0), std::move(annotations))); + if (IsSpecialTaggedMemory(e->scope)) { + MemoryInfo info = GetMemoryInfo(e->scope.to_string()); + if (info.defined()) { + uint64_t total_elem = e->const_nbits / e->elem_type.bits(); + ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) + << "Allocation exceed bound of memory tag " + << e->scope.to_string(); + } + } + } + } + } + } + // New allocation for merged data + void NewAllocTagMerged(StorageEntry *e) { + ICHECK_NE(e->scope.tag.length(), 0U); + // allocate with element type. + ICHECK_NE(e->const_nbits, 0U); + MemoryInfo info; + if (e->scope.tag != ".barrier" && e->scope.tag != ".var" && + e->scope.tag.find(".descriptor") != 0) { + info = GetMemoryInfo(e->scope.to_string()); + } + uint64_t total_bits = e->const_nbits; + // By default, align to 32 bits. 
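+ // Rounding example (illustrative): with total_bits = 100 and
+ // align = 32, the padding below gives 100 + (32 - 100 % 32) = 128, so
+ // every merged child starts on an alignment boundary and can be viewed
+ // with a different element type without a sub-word offset.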
+ size_t align = 32; + if (info.defined()) { + align = info->max_simd_bits; + } + // Always align to max_simd_bits + // so we can remap types by keeping this property + if (total_bits % align != 0) { + total_bits += align - (total_bits % align); + } + e->alloc_var = e->allocs[0]->buffer_var; + for (StorageEntry *child : e->merged_children) { + ICHECK_NE(child->const_nbits, 0U); + ICHECK_NE(total_bits, 0U); + child->bits_offset = total_bits; + child->alloc_var = e->alloc_var; + total_bits += child->const_nbits; + if (total_bits % align != 0) { + total_bits += align - (total_bits % align); + } + } + uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); + PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), + (total_bits + type_bits - 1) / type_bits); + Map annotations = MakeAllocateAnnotations(e->alloc_var); + e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size}, + const_true(), Evaluate(0), + std::move(annotations))); + if (info.defined()) { + ICHECK_LE(total_bits, info->max_num_bits) + << "Allocation exceed bound of memory tag " << e->scope.to_string(); + } + } + // Liveness analysis to find gen and kill point of each variable. + void LivenessAnalysis(const std::vector &seq) { + // find kill point, do a reverse linear scan. + std::unordered_set touched; + for (size_t i = seq.size(); i != 0; --i) { + const StmtEntry &s = seq[i - 1]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].kill.push_back(buffer); + } + } + } + // find gen point, do forward scan + touched.clear(); + for (size_t i = 0; i < seq.size(); ++i) { + int64_t offset = seq[i].scope_pair_offset; + if (offset < 0) + continue; + const StmtEntry &s = seq[i + offset]; + for (const VarNode *buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].gen.push_back(buffer); + } + } + } + } + void PlanNewScope(const Object *op) { + if (thread_scope_ != nullptr) { + ICHECK(thread_scope_ == op); + // erase all memory attached to this scope. 
+ for (auto it = const_free_map_.begin(); it != const_free_map_.end();) { + if (it->second->attach_scope_ == op) { + it = const_free_map_.erase(it); + } else { + ++it; + } + } + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) { + if ((*it)->attach_scope_ == op) { + it = sym_free_list_.erase(it); + } else { + ++it; + } + } + thread_scope_ = nullptr; + } else { + thread_scope_ = op; + } + } + + // Memory plan algorithm + void + PlanMemory(const std::vector &seq, + const std::unordered_map &alloc_info, + bool enable_reuse, bool reuse_require_exact_matched_dtype) { + std::unordered_set inplace_flag; + + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry &s = seq[i]; + auto it = event_map_.find(seq[i].stmt); + + // scope_pair_offset >= 0 means it is either + // - leaf stmt(offset = 0) + // - beginning of scope(offset < 0) + // In both cases, we need to handle the gen event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { + // Inplace operation detection + // specially handle this + bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2); + + for (const VarNode *var : it->second.gen) { + ICHECK(alloc_info.count(var)); + const AllocEntry &entry = alloc_info.at(var); + const AllocateNode *alloc = entry.alloc; + auto storage_scope = StorageScope::Create( + GetPtrStorageScope(tvm::ffi::GetRef(var))); + StorageEntry *dst_entry = nullptr; + // inplace detection + if (detect_inplace) { + // only one inplace var for s.stmt + bool inplace_found = false; + for (const VarNode *src : it->second.kill) { + if (!inplace_flag.count(src) && alloc_map_.count(src)) { + InplaceOpVerifier visitor; + StorageEntry *src_entry = alloc_map_.at(src); + if (src_entry->scope == storage_scope && + src_entry->attach_scope_ == thread_scope_ && + src_entry->elem_type == alloc->dtype.element_of() && + visitor.Check(s.stmt, var, src)) { + uint64_t const_nbits = + static_cast(alloc->ConstantAllocationSize()) * + alloc->dtype.bits() * alloc->dtype.lanes(); + if (src_entry->const_nbits == const_nbits && !inplace_found) { + // successfully inplace + dst_entry = src_entry; + inplace_flag.insert(src); + inplace_found = true; + } + } + } + } + } + if (dst_entry == nullptr) { + dst_entry = FindAlloc(alloc, thread_scope_, storage_scope, + entry.num_physical_dimensions, enable_reuse, + reuse_require_exact_matched_dtype); + } + dst_entry->allocs.emplace_back(alloc); + alloc_map_[var] = dst_entry; + } + } + // enter/exit new scope + if (s.stmt->IsInstance()) { + const auto *op = reinterpret_cast(s.stmt); + if (op->attr_key == tir::attr::thread_extent || + op->attr_key == tir::attr::virtual_thread || + tir::attr::IsPragmaKey(op->attr_key)) { + PlanNewScope(op); + } else { + ICHECK(op->attr_key == tir::attr::extern_scope); + } + } else if (s.stmt->IsInstance()) { + const auto *op = reinterpret_cast(s.stmt); + if (op->kind == ForKind::kParallel) { + if (thread_scope_ == nullptr || thread_scope_ == op) { + PlanNewScope(op); + } + } + } + // scope_pair_offset <= 0 means it is either + // - leaf stmt(offset = 0) + // - end of scope(offset < 0) + // In both cases, we need to handle the kill event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode *var : it->second.kill) { + // skip space which are already replaced by inplace + if (!inplace_flag.count(var)) { + this->Free(var); + } + } + } + } + } + // Allocate new storage entry. 
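PlanMemory above is essentially one scan over the statement sequence: variables in the gen set either reuse a compatible storage entry or receive a fresh one, and variables in the kill set return their entry to a free pool. A stripped-down model of that loop with the in-place detection, thread scopes, and dtype/scope filters elided; the names below are illustrative, not the pass's actual types:

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

struct Entry {
  uint64_t nbits;  // constant size of the storage entry in bits
};

struct Planner {
  std::multimap<uint64_t, Entry*> free_pool;   // released entries, keyed by size
  std::unordered_map<int, Entry*> assignment;  // variable id -> storage entry
  std::vector<std::unique_ptr<Entry>> storage;

  // A variable becomes live: reuse the smallest free entry that fits,
  // otherwise allocate a fresh one.
  void Gen(int var, uint64_t nbits) {
    auto it = free_pool.lower_bound(nbits);
    if (it != free_pool.end()) {
      assignment[var] = it->second;
      free_pool.erase(it);
      return;
    }
    storage.push_back(std::make_unique<Entry>(Entry{nbits}));
    assignment[var] = storage.back().get();
  }

  // A variable dies: its entry becomes available to later Gen calls.
  void Kill(int var) {
    Entry* e = assignment.at(var);
    free_pool.insert({e->nbits, e});
  }
};
```

The real pass keeps separate pools for constant-sized and symbolic-sized entries and, before falling back to this search, tries to reuse the storage of a variable killed by the same statement (the in-place case).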
+ StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope, + const StorageScope &scope, size_t const_nbits) { + ICHECK(op != nullptr); + // Reuse not successful, allocate a new buffer. + auto entry = std::make_unique(); + entry->attach_scope_ = attach_scope; + entry->scope = scope; + entry->elem_type = op->dtype.element_of(); + entry->const_nbits = const_nbits; + StorageEntry *e = entry.get(); + alloc_vec_.emplace_back(std::move(entry)); + return e; + } + + StorageEntry *FindAlloc(const AllocateNode *op, const Object *attach_scope, + const StorageScope &scope, + size_t num_physical_dimensions, bool enable_reuse, + bool reuse_require_exact_matched_dtype) { + ICHECK(op != nullptr); + // skip plan for local variable, + // compiler can do a better job with register allocation. + const uint64_t match_range = 16; + uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); + uint64_t const_nbits = + static_cast(op->ConstantAllocationSize() * op_elem_bits); + + // If the size of the array isn't known at compile-time, it must + // have its own allocation with size determined at runtime. + bool is_known_size = (const_nbits != 0); + + // Currently, only flat memory spaces can be reused. Packing + // into N-d space (e.g. 2-d texture memory on GPUs) will require + // more in-depth algorithms. + bool is_flat_memory_space = (num_physical_dimensions == 1); + + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + bool is_small_array = + (scope.tag.empty()) && + (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() || + (is_known_size && const_nbits <= 32)); + + if (!enable_reuse || is_small_array || !is_flat_memory_space) { + return NewAlloc(op, attach_scope, scope, const_nbits); + } + + if (is_known_size) { + // constant allocation. + auto begin = const_free_map_.lower_bound(const_nbits / match_range); + auto mid = const_free_map_.lower_bound(const_nbits); + auto end = const_free_map_.upper_bound(const_nbits * match_range); + // start looking at the buffer that is bigger than the required size first + for (auto it = mid; it != end; ++it) { + StorageEntry *e = it->second; + if (e->attach_scope_ != attach_scope) + continue; + if (e->scope != scope) + continue; + // when not divided, no reuse, eg, float4 vs float3 + if (e->bits_offset % op_elem_bits != 0) + continue; + if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) { + continue; + } + e->const_nbits = std::max(const_nbits, e->const_nbits); + const_free_map_.erase(it); + return e; + } + // then start looking at smaller buffers. + for (auto it = mid; it != begin;) { + --it; + StorageEntry *e = it->second; + if (e->attach_scope_ != attach_scope) + continue; + if (e->scope != scope) + continue; + if (e->elem_type != op->dtype.element_of()) + continue; + if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) { + continue; + } + e->const_nbits = std::max(const_nbits, e->const_nbits); + const_free_map_.erase(it); + return e; + } + } else { + // Simple strategy: round roubin. + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { + StorageEntry *e = *it; + if (e->attach_scope_ != attach_scope) + continue; + if (e->scope != scope) + continue; + if (e->elem_type != op->dtype.element_of()) + continue; + sym_free_list_.erase(it); + return e; + } + } + return NewAlloc(op, attach_scope, scope, const_nbits); + } + // simulated free. 
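The constant-size branch of FindAlloc above does not reuse arbitrary free entries: candidates are restricted to the window [const_nbits / match_range, const_nbits * match_range], entries at least as large are tried first, and a smaller entry is grown when it is reused. A standalone sketch of that search over a std::multimap, where `compatible` stands in for the scope, attach-scope, and dtype checks of the real pass:

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>

struct Entry {
  uint64_t nbits;  // current size of the entry in bits
};

Entry* FindReusable(std::multimap<uint64_t, Entry*>& free_map, uint64_t nbits,
                    const std::function<bool(const Entry*)>& compatible,
                    uint64_t match_range = 16) {
  auto begin = free_map.lower_bound(nbits / match_range);
  auto mid = free_map.lower_bound(nbits);
  auto end = free_map.upper_bound(nbits * match_range);
  // First try entries that are already large enough.
  for (auto it = mid; it != end; ++it) {
    if (!compatible(it->second)) continue;
    Entry* e = it->second;
    e->nbits = std::max(nbits, e->nbits);
    free_map.erase(it);
    return e;
  }
  // Then try smaller entries inside the window and grow them to fit.
  for (auto it = mid; it != begin;) {
    --it;
    if (!compatible(it->second)) continue;
    Entry* e = it->second;
    e->nbits = std::max(nbits, e->nbits);
    free_map.erase(it);
    return e;
  }
  return nullptr;  // caller falls back to a fresh allocation
}
```

Bounding the search window keeps a tiny request from being packed into a huge free entry (and vice versa), which would waste memory or defeat later register promotion.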
+ void Free(const VarNode *var) { + auto it = alloc_map_.find(var); + ICHECK(it != alloc_map_.end()); + StorageEntry *e = it->second; + ICHECK_NE(e->allocs.size(), 0U); + + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + if (e->scope.tag.empty()) { + // Disable sharing of local memory. + if (e->scope.rank >= StorageRank::kWarp || + e->allocs[0]->dtype.is_handle()) + return; + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) + return; + } + // normal free. + if (e->const_nbits != 0) { + const_free_map_.insert({e->const_nbits, e}); + } else { + sym_free_list_.push_back(e); + } + } + // thread scope. + const Object *thread_scope_{nullptr}; + // whether enable inplace detection. + bool detect_inplace_{false}; + // Locations of free ops. + std::unordered_map event_map_; + // constant size free map. + std::multimap const_free_map_; + // symbolic free list, for non constant items. + std::list sym_free_list_; + // The allocation attach map + std::unordered_map> attach_map_; + // The allocation assign map + std::unordered_map alloc_map_; + // The allocations + std::vector> alloc_vec_; + // The buffer objects being remapped + std::unordered_map buffer_remap_; + // Buffers whose DeclBuffer has been hoisted to be adjacent to the new + // Allocate location + std::unordered_set hoisted_buffer_decls_; + // Any buffers that is accessed at some point. DeclBuffer instances + // that do not appear in this list may be removed. + std::unordered_set all_buffers_accessed_; + // Initial values for local variable buffers. + Map local_var_init_map_; + // analyzer + arith::Analyzer analyzer_; +}; + +/* Helper struct containing information on how a buffer is declared and used + * + */ +struct BufferVarInfo { + enum DeclarationLocation : uint8_t { + kPrimFuncParam = (1 << 0), + kPrimFuncBufferMap = (1 << 1), + kAllocateNode = (1 << 2), + kAllocateConstNode = (1 << 3), + kLetNode = (1 << 4), + }; + + // The tir::Var that represents this buffer. + Var var; + + // The data type of an element of the buffer. + DataType element_dtype; + + /* The extent of the buffer. + * + * If multidimensional, the extent of the last dimension of the buffer. If + * the size is unknown (e.g. pointer arguments to PrimFunc with no + * corresponding entry in buffer_map), then extent is zero. + */ + PrimExpr extent; + + // Where the buffer was declared + DeclarationLocation declaration_location; + + // When accessed, which element type is it accessed as. This may + // differ both in base type (e.g. int32* cast to float32* after + // packing in StorageRewrite) or in number of lanes (e.g. float16* + // cast to float16x4*). + std::unordered_set access_dtype; + // Data types used for scalar reads. This is used to record vectorized read + // dtypes that can be shuffled for scalar reads when + // rewrite_scalar_read_to_vector_shuffle is enabled. + std::unordered_set scalar_read_dtype; + + DataType get_preferred_dtype() const { + std::unordered_set base_access_dtype; + for (auto dtype : access_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + for (auto dtype : scalar_read_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + // If the array is accessed as multiple base types within a + // function, no point in changing the declared type. CodeGenC can + // handle this with a type-cast prior to indexing. Vulkan will + // raise an error at code-gen time, if a later pass doesn't split + // it out. 
+ if (base_access_dtype.size() != 1) { + return element_dtype; + } + + DataType preferred_base_type = *base_access_dtype.begin(); + + // If there is only one vectorizable size used to access the + // buffer, and if that access size is compatible with the array + // size, then the buffer is vectorizable. In the future, this + // could be improved to allow vectorized buffer access of size + // GCD(*lanes_used), if necessary. + // When there are scalar reads and no writes, access_dtype can be empty and + // we should avoid rewriting. + int preferred_lanes = element_dtype.lanes(); + if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) { + int lanes = access_dtype.begin()->lanes(); + // Check the scalar read dtypes are compatible with the vectorized access + // dtype. + for (auto dtype : scalar_read_dtype) { + if (dtype.lanes() % lanes != 0) { + return element_dtype; + } + } + arith::Analyzer analyzer_; + arith::ModularSet me = analyzer_.modular_set(extent); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + preferred_lanes = lanes; + } + } + + return preferred_base_type.with_lanes(preferred_lanes); + } +}; + +/* Checks whether buffers are accessed as scalar or vector parameters in a + * function. + * + */ +class VectorTypeAccessChecker : public StmtExprVisitor { +public: + /* Constructor + * + * @param params The parameters passed to a PrimFunc + * + * @param buffer_map The buffer_map associated with a PrimFunc + * + * @param allow_untyped_handles If a buffer or pointer variable is + * missing a type annotation, assume that it has the same underlying + * type as it is later accessed, with scalar element types. + */ + VectorTypeAccessChecker(const Array ¶ms, + const Map &buffer_map, + bool allow_untyped_pointers = false, + bool detect_scalar_read_patterns = true) + : allow_untyped_pointers_(allow_untyped_pointers), + detect_scalar_read_patterns_(detect_scalar_read_patterns) { + // If a parameter is in the buffer map, we want to track the + // version in the map. + for (auto it : buffer_map) { + Buffer &buffer = it.second; + Var buffer_var = buffer->data; + DataType dtype = buffer->dtype; + PrimExpr extent = + !buffer->shape.empty() ? buffer->shape[buffer->shape.size() - 1] : 0; + OnArrayDeclaration(buffer_var, dtype, extent, + BufferVarInfo::kPrimFuncParam); + } + + // If a pointer parameter isn't in the buffer map, then we want to + // track the parameter itself. 
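get_preferred_dtype above widens a buffer's element type only when every vectorized access uses a single lane count, all recorded scalar-read widths are compatible with it, and the buffer extent is provably a multiple of that lane count. With a concrete extent in place of the analyzer's ModularSet, the rule reduces to roughly the following; the function and parameter names are illustrative:

```cpp
#include <cstdint>
#include <set>

// Decide how many lanes the buffer's element type may take.
int PreferredLanes(const std::set<int>& vector_access_lanes,
                   const std::set<int>& scalar_read_lanes, int64_t extent) {
  if (vector_access_lanes.size() != 1) return 1;  // mixed widths: stay scalar
  int lanes = *vector_access_lanes.begin();
  for (int l : scalar_read_lanes) {
    if (l % lanes != 0) return 1;  // a scalar-read pattern the width cannot cover
  }
  if (extent % lanes != 0) return 1;  // extent not divisible: cannot repack
  return lanes;
}
```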
+ for (Var buffer_var : params) { + auto pointer_type = GetPointerType(buffer_var->type_annotation); + if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) { + DataType dtype = pointer_type.value(); + PrimExpr extent = 0; + OnArrayDeclaration(buffer_var, dtype, extent, + BufferVarInfo::kPrimFuncBufferMap); + } + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices, + /*is_buffer_load=*/true); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices, + /*is_buffer_load=*/false); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + DataType dtype = op->args[0].dtype(); + const VarNode *buffer = op->args[1].as(); + PrimExpr index = op->args[2]; + OnArrayAccess(dtype, buffer, {index}, false); + } else if (op->op.same_as(builtin::address_of())) { + if (auto load = op->args[0].as()) { + OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices, + /*is_buffer_load=*/false); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const AllocateNode *op) final { + const Array &extents = op->extents; + PrimExpr extent = extents[extents.size() - 1]; + OnArrayDeclaration(op->buffer_var, op->dtype, extent, + BufferVarInfo::kAllocateNode); + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateConstNode *op) final { + const Array &extents = op->extents; + PrimExpr extent = + !extents.empty() ? extents[extents.size() - 1] : NullValue(); + OnArrayDeclaration(op->buffer_var, op->dtype, extent, + BufferVarInfo::kAllocateConstNode); + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const LetNode *op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode *op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitStmt_(op); + } + + void HandleLetNode(const Var &let_var) { + if (let_var->dtype.is_handle()) { + auto pointer_type = GetPointerType(let_var->type_annotation); + if (pointer_type.has_value()) { + OnArrayDeclaration(let_var, pointer_type.value(), 0, + BufferVarInfo::kLetNode); + } else if (allow_untyped_pointers_) { + OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); + } else { + LOG(FATAL) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; + } + } + } + + /* Update the type map for a buffer based on its declaration + * + * @param buffer The VarNode representing the buffer. + * + * @param element_dtype The dtype of a single element of the buffer. + * If unknown, when used with the allow_untyped_handles option, + * should be a handle dtype. + * + * @param extent The extent of the buffer. Zero if size is unknown. + * + * @param declaration_location How the buffer was allocated, so that + * some locations can be rewritten without others. + */ + void + OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent, + BufferVarInfo::DeclarationLocation declaration_location) { + auto it = info_map_.find(buffer.get()); + if (it != info_map_.end()) { + // The same buffer var may appear in more than one Allocate due to + // upstream transforms (e.g., storage planning/merging). Treat repeated + // declarations as benign and merge metadata instead of erroring. 
+ BufferVarInfo &existing = it->second; + // Prefer a concrete element dtype if the previous one was a handle. + if (existing.element_dtype.is_handle() && !element_dtype.is_handle()) { + existing.element_dtype = + element_dtype == DataType::Bool() + ? DataType::Int(8).with_lanes(element_dtype.lanes()) + : element_dtype; + } + // If extent was previously unknown (0) and a concrete extent is + // provided now, record it. + if (!existing.extent.defined() || is_zero(existing.extent)) { + existing.extent = extent; + } + // Merge declaration locations (bitwise OR of flags). + existing.declaration_location = + static_cast( + existing.declaration_location | declaration_location); + return; + } + + if (element_dtype == DataType::Bool()) { + element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); + } + info_map_[buffer.get()] = BufferVarInfo{ + buffer, element_dtype, std::move(extent), declaration_location}; + } + + /* Update the type map for a buffer based on its usage + * + * @param value_dtype The dtype of the value being stored to or + * loaded from the buffer. + * + * @param buffer The VarNode representing the buffer. + * + * @param indices The index at which the value is being stored/loaded. + * + * @param is_buffer_load Whether the access is BufferLoad + */ + void OnArrayAccess(DataType value_dtype, const VarNode *buffer, + const Array &indices, bool is_buffer_load) { + auto it = info_map_.find(buffer); + ICHECK(it != info_map_.end()) + << "Load/Store of buffer " << buffer->name_hint << " (" << buffer + << ") occurred before its declaration."; + + if (value_dtype.is_scalable_vector()) { + // Scalable types are not currently supported in storage_rewrite. Scalable + // buffer accesses are not currently checked and therefore are not + // rewritten. + return; + } + + BufferVarInfo &var_info = it->second; + + if (value_dtype.element_of() == DataType::Bool()) { + value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); + } + + if (var_info.element_dtype.is_handle()) { + ICHECK(allow_untyped_pointers_) + << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; + var_info.element_dtype = value_dtype.element_of(); + } + + for (int i = 0; i < static_cast(indices.size()) - 1; i++) { + ICHECK(indices[i].dtype().is_scalar()) + << "Only the last index of a buffer access may be a vector type."; + } + int index_lanes = !indices.empty() ? indices.back().dtype().lanes() : 1; + + DataType access_dtype = value_dtype; + + int lanes_used = var_info.element_dtype.lanes(); + + // This can happen due to a previous pass that had rewrite_store_load = + // false. This occurs from the StorageRewrite in tvm::lower, followed by + // the PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = + // false is necessary because the C-based codegens do not yet support + // vectorized pointer types (e.g. float16x4*). Once they do, this if + // statement should instead be replaced by the below ICHECK_EQ. + if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index_lanes, value_dtype.lanes()); + lanes_used = 1; + var_info.element_dtype = var_info.element_dtype.with_lanes(1); + } + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. 
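The Ramp-index branch that follows applies the core alignment test for a vectorizable access: an index of the form Ramp(base, stride=1, lanes) can be treated as one lanes-wide access only if the base offset is a multiple of the lane count (for symbolic bases the pass proves this with arith::ModularSet). As a standalone check over concrete values, with illustrative names:

```cpp
#include <cstdint>

// A unit-stride ramp index: base, base + 1, ..., base + lanes - 1.
struct RampIndex {
  int64_t base;
  int64_t stride;
  int lanes;
};

// How many lanes this access can safely be vectorized to.
int VectorLanesUsed(const RampIndex& index) {
  if (index.stride != 1) return 1;              // non-contiguous: keep scalar
  if (index.base % index.lanes != 0) return 1;  // misaligned base: keep scalar
  return index.lanes;
}
```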
+ + // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), + // value_dtype.lanes()) + // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of + // data with " + // << index_lanes << " indices into an array whose elements have " + // << var_info.element_dtype.lanes() << " lanes. " + // << "Expected output with " << index_lanes * + // var_info.element_dtype.lanes() + // << " lanes."; + + // If the index is a RampNode with stride of 1 and offset + // divisible by the number of number of lanes, and the predicate + // does not apply any masking, then this array access could be + // vectorized. + if (!indices.empty()) { + const RampNode *ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { + if (ramp_index->lanes->IsInstance()) { + int lanes = + static_cast(Downcast(ramp_index->lanes)->value); + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + lanes_used = lanes; + } + } + } + } + + if (detect_scalar_read_patterns_ && is_buffer_load && !indices.empty()) { + const PrimExpr last_dim_index = indices[indices.size() - 1]; + if (last_dim_index.dtype().lanes() == 1) { + arith::ModularSet me = analyzer_.modular_set(last_dim_index); + var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff)); + return; + } + } + var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); + } + + // Map of buffer variable information determined + std::unordered_map info_map_; + + // + bool allow_untyped_pointers_{false}; + // Whether to detect scalar read patterns for rewriting to vector shuffle + bool detect_scalar_read_patterns_{true}; + + // internal analyzer + arith::Analyzer analyzer_; +}; + +/* \brief Rewrites buffer/pointer variables from scalar types to vectorized + * types. + * + * Some runtimes do not allow casting between composite types and the underlying + * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*). + * In these cases, in order to have vectorized load/store on an array, the + * element type of that array must be vectorized. This is in contrast to + * C-style runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + + * offset)` is valid. + * + * By default, VectorTypeRewriter will attempt to rewrite all buffer variables + * to vectorized access, if the load/store occurring in the PrimFunc are all + * vectorized. This includes adjusting the indices being used to access the + * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4* + * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to + * `vec_arr[offset/4]`.) + * + * Currently, several of the C-style runtimes do not support buffers whose + * elements are vectorized types, or rely on the presence of the Ramp nodes to + * identify vectorized loads. The boolean parameters in the constructor are to + * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these + * runtimes. Once all runtimes support vectorized buffer elements, these + * parameters can be removed. + */ +class VectorTypeRewriter : public StmtExprMutator { +public: + /* Constructor + * + * @param checker The VectorTypeAccessChecker that has previously read out + * information from the PrimFunc + * + * @param rewrite_params Whether pointer-type parameters passed into the + * function should be rewritten from scalar types to vectorized types. 
+ * + * @param rewrite_buffer_map Whether buffers present in the buffer_map should + * have their data variable be rewritten from scalar types to vectorized + * types. + * + * @param rewrite_allocate_node Whether the buffer variable associated with + * AllocateNodes should be rewritten from scalar types to vectorized types. + * + * @param rewrite_indices Whether the indices to the Load and Store nodes + * should be rewritten to correspond to the new buffer_var type. + * + * @param rewrite_let_node Whether pointer declarations in let nodes + * should be re-written. + */ + VectorTypeRewriter( + const std::unordered_map &info_map, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true, bool rewrite_allocate_const_node = true, + bool rewrite_scalar_read_to_vector_shuffle = true) + : rewrite_indices_(rewrite_indices) { + int rewrite_mask = 0; + if (rewrite_params) { + rewrite_mask |= BufferVarInfo::kPrimFuncParam; + } + if (rewrite_buffer_map) { + rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap; + } + if (rewrite_allocate_node) { + rewrite_mask |= BufferVarInfo::kAllocateNode; + } + if (rewrite_let_node) { + rewrite_mask |= BufferVarInfo::kLetNode; + } + if (rewrite_allocate_const_node) { + rewrite_mask |= BufferVarInfo::kAllocateConstNode; + } + + // Rewrite any buffer variables whose preferred type isn't their current + // type. + for (const auto &pair : info_map) { + const auto &var_info = pair.second; + DataType preferred = var_info.get_preferred_dtype(); + if (preferred != var_info.element_dtype && + (rewrite_mask & var_info.declaration_location)) { + Var old_buffer_var = var_info.var; + Var new_buffer_var(old_buffer_var->name_hint, + PointerType(PrimType(preferred), + GetPtrStorageScope(old_buffer_var)), + old_buffer_var->span); + + rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, + var_info.element_dtype, preferred}; + } + } + } + + /*! + * \brief Mutator for BufferLoad or BufferStore. + * \return The rewritten node and the shuffle index. (Only for BufferLoad) + * When the shuffle index is non-negative, the caller should generate Shuffle + * to extract the element from the vector. + */ + template std::pair VisitBufferAccess(Node node) { + int shuffle_index = -1; + if (!rewrite_indices_) { + return {node, shuffle_index}; + } + + auto it = rewrite_map_.find(node->buffer->data.get()); + if (it == rewrite_map_.end()) { + return {node, shuffle_index}; + } + const auto &info = it->second; + + Array indices = node->indices; + const PrimExpr &last_dim_index = indices[indices.size() - 1]; + const RampNode *ramp_index = indices[indices.size() - 1].as(); + + if (node->buffer->dtype.is_scalable_vector() || + last_dim_index.dtype().is_scalable_vector()) { + // Scalable types are not currently supported in storage_rewrite. Scalable + // buffer accesses are not currently checked and therefore are not + // rewritten. 
+ return {node, shuffle_index}; + } + + if (ramp_index && is_one(ramp_index->stride) && + ramp_index->lanes->IsInstance()) { + int lanes = static_cast(Downcast(ramp_index->lanes)->value); + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), lanes); + if (lanes != info.factor()) { + ICHECK(info.factor() && lanes % info.factor() == 0); + int new_lanes = lanes / info.factor(); + new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, + ramp_index->span); + } + indices.Set(indices.size() - 1, new_index); + } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) { + arith::ModularSet me = analyzer_.modular_set(last_dim_index); + ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); + PrimExpr new_index = + last_dim_index / make_const(last_dim_index.dtype(), info.factor()); + shuffle_index = me->base % info.factor(); + ; + indices.Set(indices.size() - 1, new_index); + } + + auto writer = node.CopyOnWrite(); + writer->buffer = RemapBuffer(node->buffer); + writer->indices = indices; + return {node, shuffle_index}; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto [modified, shuffle_index] = VisitBufferAccess(node); + + // Not needed for BufferStoreNode, so we can't just call + // LegalizeDtype() in VisitBufferAccess. + if (node.same_as(modified)) { + return std::move(node); + } else { + auto writer = modified.CopyOnWrite(); + // writer->LegalizeDType(); + LegalizeBufferLoadDType(writer); + if (shuffle_index >= 0) { + return Shuffle::ExtractElement(std::move(modified), shuffle_index); + } + return std::move(modified); + } + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + auto [modified, shuffle_index] = VisitBufferAccess(std::move(node)); + ICHECK(shuffle_index < 0); + return std::move(modified); + } + + Stmt VisitStmt_(const LetStmtNode *op) final { + auto it = rewrite_map_.find(op->var.get()); + PrimExpr value = this->VisitExpr(op->value); + Stmt body = this->VisitStmt(op->body); + Var var = (it == rewrite_map_.end()) ? 
op->var : it->second.new_buffer_var; + if (var.same_as(op->var) && value.same_as(op->value) && + body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } + return LetStmt(var, value, body); + } + + Buffer RemapBuffer(Buffer buf) { + auto cache_key = buf.get(); + + auto cache_it = buffer_map_.find(cache_key); + if (cache_it != buffer_map_.end()) { + return cache_it->second; + } + + auto info_it = rewrite_map_.find(buf->data.get()); + if (info_it != rewrite_map_.end()) { + auto &info = info_it->second; + + Array shape = buf->shape; + PrimExpr last_dim = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, + last_dim / make_const(last_dim.dtype(), info.factor())); + + auto writer = buf.CopyOnWrite(); + writer->data = info.new_buffer_var; + writer->dtype = info.new_element_dtype; + writer->shape = shape; + } + + buffer_map_[cache_key] = buf; + return buf; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + if (!rewrite_indices_) { + return expr; + } + + const VarNode *buffer_var = op->args[1].as(); + auto it = rewrite_map_.find(buffer_var); + if (it == rewrite_map_.end()) { + return expr; + } + const auto &info = it->second; + + PrimExpr index = op->args[2]; + PrimExpr extent = op->args[3]; + PrimExpr flag = op->args[4]; + + PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); + int factor = info.factor(); + extent = extent / make_const(extent.dtype(), factor); + index = index / make_const(index.dtype(), factor); + Array acc_args{e_dtype, info.new_buffer_var, index, extent, + flag}; + return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); + + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + + Stmt VisitStmt_(const AllocateNode *op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + + const auto &info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + Array extents = op->extents; + PrimExpr last_extent = extents[extents.size() - 1]; + extents.Set(extents.size() - 1, + last_extent / make_const(last_extent.dtype(), info.factor())); + DLOG(INFO) << "Allocate with " << new_buffer_var << " and " + << info.new_element_dtype << " extents: " << extents; + return Allocate(new_buffer_var, info.new_element_dtype, extents, + op->condition, op->body, op->annotations); + } + + Stmt VisitStmt_(const AllocateConstNode *op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + + const auto &info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); + + Array extents = op->extents; + extents.Set(extents.size() - 1, extents[extents.size() - 1] / + make_const(extents[0].dtype(), factor)); + return AllocateConst(new_buffer_var, info.new_element_dtype, extents, + op->data, op->body); + } + + /* Update the parameters and all remaining variable references + * + * Should be called after calling operator() on the body of the + * function. + * + * @param func A pointer to the PrimFunc being modified. 
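VisitBufferAccess above rescales indices when a buffer's element type is widened by a factor F: a unit-stride Ramp of F lanes collapses to the single vector-element index base / F, while a scalar load at index i becomes a load of vector element i / F followed by a Shuffle that extracts lane i % F (when that lane can be determined statically). A small model of just the index arithmetic, omitting the case where the ramp is wider than F; the helper below is illustrative only:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Rewrite a scalar-element index for a buffer whose element type was widened
// by `factor`. Returns {new_element_index, shuffle_lane}; shuffle_lane is -1
// when the access already covers a whole vector element.
std::pair<int64_t, int> RewriteIndex(int64_t index, int lanes, int factor) {
  if (lanes == factor) {
    // Ramp(index, 1, factor): the access maps to exactly one vector element.
    assert(index % factor == 0);
    return {index / factor, -1};
  }
  // Scalar access: load the containing vector element, then shuffle out a lane.
  assert(lanes == 1);
  return {index / factor, static_cast<int>(index % factor)};
}
```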
+ */ + void Finalize(PrimFunc *func_ptr) { + ICHECK(func_ptr) << "Finalize expects a non-null pointer"; + auto &func = *func_ptr; + auto *n = func.CopyOnWrite(); + + // Remap any remaining references to the old buffer variables + Map var_remap; + for (const auto &pair : rewrite_map_) { + const auto &info = pair.second; + var_remap.Set(info.old_buffer_var, info.new_buffer_var); + } + n->body = Substitute(n->body, var_remap); + + // Remap the argument list to use the new buffer variables. + Array new_params; + for (const auto &old_param : n->params) { + auto it = rewrite_map_.find(old_param.get()); + if (it == rewrite_map_.end()) { + new_params.push_back(old_param); + } else { + const auto &info = it->second; + new_params.push_back(info.new_buffer_var); + } + } + n->params = new_params; + + // Remap the Buffer objects in PrimFunc::buffer_map so that the + // buffers use the new buffer variables + Map new_buffer_map; + for (const auto &pair : n->buffer_map) { + Var key = pair.first; + Buffer old_buffer = pair.second; + Var old_var = old_buffer->data; + Buffer new_buffer = RemapBuffer(old_buffer); + new_buffer_map.Set(key, new_buffer); + } + n->buffer_map = new_buffer_map; + } + +private: + struct RewriteInfo { + Var old_buffer_var; + Var new_buffer_var; + DataType old_element_dtype; + DataType new_element_dtype; + + int factor() const { + int old_lanes = old_element_dtype.lanes(); + int new_lanes = new_element_dtype.lanes(); + ICHECK_EQ(new_lanes % old_lanes, 0); + return new_lanes / old_lanes; + } + }; + + bool rewrite_indices_{true}; + std::unordered_map rewrite_map_; + std::unordered_map buffer_map_; + arith::Analyzer analyzer_; +}; + +// Rewrite allocates, pointer parameters, and buffer map into vectorized +// versions if each access into a buffer is the same vector type. +PrimFunc PointerValueTypeRewrite( + PrimFunc f, bool allow_untyped_pointers = false, bool rewrite_params = true, + bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, + bool rewrite_indices = true, bool rewrite_let_node = true, + bool rewrite_allocate_const_node = true, + bool rewrite_scalar_read_to_vector_shuffle = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, + allow_untyped_pointers, + rewrite_scalar_read_to_vector_shuffle); + checker(f->body); + + VectorTypeRewriter rewriter( + checker.info_map_, rewrite_params, rewrite_buffer_map, + rewrite_allocate_node, rewrite_indices, rewrite_let_node, + rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle); + PrimFuncNode *n = f.CopyOnWrite(); + n->body = rewriter(std::move(n->body)); + rewriter.Finalize(&f); + + return f; +} + +using namespace tir::transform; +namespace transform { +Pass StorageRewrite() { + auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) { + bool detect_inplace = + ctx->GetConfig(kStorageRewriteDetectInplace, Bool(false)).value(); + bool enable_reuse = true; + bool reuse_require_exact_matched_dtype = false; + bool merge_static_smem = + ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + AllocateCollector collector; + collector(f->body); + bool has_dynamic = collector.dyn_shmem_allocs_.size() > 1; + if (has_dynamic || merge_static_smem) { + // For IRModule utilizing dynamic shared memory, reuse is not enabled + // Because dynamic doesn't require maintaining the readability and + // it benefits from a more optimized allocation strategy through the + // Pass `MergeSharedMemoryAllocations`. 
+ // When `merge_static_smem` is true, we will reuse and merge shared + // memory in a dedicated pass `MergeSharedMemoryAllocations`. + // And so we don't enable reuse in this pass. + enable_reuse = false; + } + + Optional target = f->GetAttr("target"); + if (target.defined() && (target.value()->kind->name == "vulkan" || + target.value()->kind->name == "webgpu")) { + // Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU + reuse_require_exact_matched_dtype = true; + } + Map local_var_init_map; + if (auto init_map = + f->attrs.GetAttr>(tl::attr::kLocalVarInit)) { + local_var_init_map = init_map.value(); + } + auto *n = f.CopyOnWrite(); + StoragePlanRewriter plan_rewriter; + n->body = plan_rewriter.Rewrite( + std::move(n->body), detect_inplace, enable_reuse, + reuse_require_exact_matched_dtype, std::move(local_var_init_map)); + // Parameters may not be rewritten, but internal allocations may. + // Vectorization of AllocateConst is currently disabled, as it has + // indexing issues for types that include padding (e.g. int8x3 + // padded out to 32 bits) would require either rewriting + // AllocateConst::data, or would require the code generators to + // handle vectorized constants. + return PointerValueTypeRewrite(std::move(f), true, false, false, false, + true, true, false, false); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite); +} + +Pass PointerValueTypeRewrite() { + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return tl::PointerValueTypeRewrite(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite", + PointerValueTypeRewrite); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/thread_storage_sync.cc b/tilelang/original/src/transform/thread_storage_sync.cc new file mode 100644 index 0000000000000000000000000000000000000000..0627678e18ad1eed26ec83a8842b29755fb60fd3 --- /dev/null +++ b/tilelang/original/src/transform/thread_storage_sync.cc @@ -0,0 +1,860 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file thread_storage_sync.cc + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/builtin.h" +#include "./common/thread_sync_types.h" +#include "./storage_access.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; + +class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor { +public: + explicit TileLangThreadSyncPlanner(StorageScope sync_scope) + : sync_scope_(std::move(sync_scope)) {} + + // The syncs inserted before each statement + std::unordered_set syncs_inserted_; + +protected: + bool Enabled(const VarNode *buf, const StorageScope &scope) const final { + return in_device_env() && scope == sync_scope_; + } + // Plan the sync + std::vector Summarize(std::vector seq, + const ForNode *loop) final { + // Redirect all "shared.dyn" buffer access to the same buffer var + // so that the accesses can be planned together. + Var shared_dyn_buf; + for (StmtEntry &entry : seq) { + for (AccessEntry &access : entry.access) { + if (access.scope.rank == StorageRank::kShared && + access.scope.tag == ".dyn" && access.buffer.defined()) { + if (!shared_dyn_buf.defined()) { + shared_dyn_buf = access.buffer; + } else { + access.buffer = shared_dyn_buf; + } + } + } + } + + // Unsynced reads and writes + std::vector reads; + std::vector writes; + // if it is a loop, rotate two times to consider effect of loop. + // simulation based approach to find dependencies + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry &s = seq[i]; + // check if sync before statement is needed. + bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0); + // Apply the syncs added already. + + if (sync_before_stmt) { + reads.clear(); + writes.clear(); + } + + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + if (FindConflict(writes, acc, false)) { + sync_before_stmt = true; + break; + } + } else if (acc.type == kWrite) { + if (FindConflict(reads, acc, false) || + FindConflict(writes, acc, false)) { + sync_before_stmt = true; + break; + } + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + // If sync is inserted. remove the irrelevant things. + if (sync_before_stmt) { + reads.clear(); + writes.clear(); + } + // Add the read/write of current statement + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + reads.push_back(acc); + } else if (acc.type == kWrite) { + writes.push_back(acc); + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + + if (sync_before_stmt) { + insert_syncs(s.stmt); + } + } + if (loop != nullptr) { + // Check if the loop body contains any reads in the same sync scope. + // If there are reads, we conservatively keep the sync within the loop + // body to preserve per-iteration ordering when needed. If there are no + // reads (e.g., only writes to shared.dyn), we can safely hoist the sync + // to before the loop to avoid redundant barriers. + bool has_read_in_scope = false; + for (const StmtEntry &s : seq) { + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead && acc.scope == sync_scope_) { + has_read_in_scope = true; + break; + } + } + if (has_read_in_scope) + break; + } + // If there is a loop-carried dependency, insert a single sync + // before the loop rather than hoisting a sync into the loop body. 
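The per-sequence scan in Summarize above keeps sets of reads and writes that are not yet separated by a barrier and requests a sync in front of a statement as soon as a read-after-write, write-after-read, or write-after-write hazard appears; an existing sync clears both sets. A compact model of that scan over abstract access records, where a buffer id stands in for the full buffer-plus-index conflict test:

```cpp
#include <set>
#include <vector>

enum class AccessType { kRead, kWrite, kSync };

struct Access {
  AccessType type;
  int buffer;  // which buffer is touched
};

struct Stmt {
  int id;
  std::vector<Access> accesses;
};

// Returns the ids of statements that need a barrier inserted in front of them.
std::set<int> PlanSyncs(const std::vector<Stmt>& seq) {
  std::set<int> syncs;
  std::set<int> reads, writes;  // buffers with un-synced reads / writes
  for (const Stmt& s : seq) {
    bool sync_before = false;
    for (const Access& a : s.accesses) {
      if (a.type == AccessType::kRead && writes.count(a.buffer)) sync_before = true;
      if (a.type == AccessType::kWrite &&
          (reads.count(a.buffer) || writes.count(a.buffer)))
        sync_before = true;
      if (a.type == AccessType::kSync) { reads.clear(); writes.clear(); }
    }
    if (sync_before) {
      syncs.insert(s.id);
      reads.clear();  // the new barrier hides everything before it
      writes.clear();
    }
    for (const Access& a : s.accesses) {
      if (a.type == AccessType::kRead) reads.insert(a.buffer);
      if (a.type == AccessType::kWrite) writes.insert(a.buffer);
    }
  }
  return syncs;
}
```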
+ // This reduces redundant per-iteration synchronizations for cases + // where each iteration touches disjoint regions (e.g., stmatrix + // writes to shared.dyn) and only a global ordering before/after the + // loop is required. + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry &s = seq[i]; + if (syncs_inserted_.count(s.stmt) != 0) + break; + if (reads.empty() && writes.empty()) + break; + bool need_loop_sync = false; + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + if (FindConflict(writes, acc, true)) { + need_loop_sync = true; + break; + } + } else if (acc.type == kWrite) { + if (FindConflict(reads, acc, true) || + FindConflict(writes, acc, true)) { + need_loop_sync = true; + break; + } + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + if (need_loop_sync) { + if (!has_read_in_scope) { + // Mark the loop itself to receive a sync before it, instead of + // inserting inside the loop body. This ensures a single sync is + // emitted outside the loop and avoids per-iteration overhead. + insert_syncs(loop); + } else { + // Fall back to inserting before the first conflicting statement + // inside the loop to maintain correctness when reads are present. + insert_syncs(s.stmt); + } + break; + } + } + } + // return the exposed entries, remove unnecessary ones. + int sync_count = 0; + // head are before first sync, tail are after last sync + std::vector head, tail; + AccessEntry esync; + esync.threads = this->env_threads(); + esync.thread_range = this->ComputeThreadRange(esync.threads); + esync.type = kSync; + esync.scope = sync_scope_; + + for (const StmtEntry &s : seq) { + if (syncs_inserted_.count(s.stmt)) { + if (sync_count != 0) { + tail.clear(); + } else { + head.push_back(esync); + } + ++sync_count; + } + for (const AccessEntry &acc : s.access) { + if (acc.type == kSync) { + if (sync_count != 0) { + tail.clear(); + } else { + head.push_back(esync); + } + ++sync_count; + } else { + if (sync_count != 0) { + tail.push_back(acc); + } else { + head.push_back(acc); + } + } + } + } + head.insert(head.end(), tail.begin(), tail.end()); + if (loop != nullptr) { + // clear double buffer flag after a loop is finished. + for (AccessEntry &e : head) { + e.double_buffer_write = false; + } + } + return head; + } + +private: + // find conflicting entry in vec. + bool FindConflict(const std::vector &prev, + const AccessEntry &curr, bool loop_carry) { + for (const AccessEntry &x : prev) { + if (FindConflict(x, curr, loop_carry)) { + return true; + } + } + return false; + } + + bool FindConflict(const AccessEntry &prev, const AccessEntry &curr, + bool loop_carry) { + // Special case: ignore conflicts between async-copy writes (e.g., TMA + // loads into shared memory). Multiple async writes do not require + // interspersed barriers among themselves. We still respect conflicts with + // reads to ensure visibility before consumption. + if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy && + curr.is_async_copy) { + return false; + } + // Access to different buffers does not conflict. + if (!prev.buffer.same_as(curr.buffer)) { + return false; + } + + // Assumes no race between threads + // Same index value means no conflicts + // TODO(tqchen) more standard set based testing. 
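After planning, Summarize reports to the enclosing scope only the accesses that remain exposed: everything before the first barrier (prefixed by a synthetic sync marker when a barrier exists) plus everything after the last barrier, while accesses fully enclosed between two barriers are dropped. A small sketch of that head/tail partition over strings, where "S" marks a barrier and any other string is an access:

```cpp
#include <string>
#include <vector>

std::vector<std::string> ExposedAccesses(const std::vector<std::string>& seq) {
  std::vector<std::string> head, tail;
  int sync_count = 0;
  for (const std::string& e : seq) {
    if (e == "S") {
      if (sync_count != 0) tail.clear();  // a later barrier hides the old tail
      else head.push_back(e);             // keep one marker before the first sync
      ++sync_count;
    } else if (sync_count != 0) {
      tail.push_back(e);                  // after the last barrier seen so far
    } else {
      head.push_back(e);                  // before the first barrier
    }
  }
  head.insert(head.end(), tail.begin(), tail.end());
  return head;
}
```

For the input {a, S, b, S, c} this yields {a, S, c}: access b sits between two barriers and is invisible to the parent scope.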
+ bool has_same_index = true; + bool range_is_equal = true; + bool range_is_overlap = true; + + for (const auto &kv : prev.thread_range) { + if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { + range_is_equal = false; + break; + } + } + + if (prev.buffer_indices.size() != curr.buffer_indices.size()) { + // They are not the same indices, should be conflict. + return true; + } + if (prev.is_pointer_access || curr.is_pointer_access) { + // For accesses created via tvm_access_ptr we may still be able to prove + // disjointness using their byte ranges. If both sides expose a touched + // interval and we can show they don't overlap, skip the conflict. + if (prev.is_pointer_access && curr.is_pointer_access && + PointerAccessIsDisjoint(prev, curr)) { + return false; + } + // Otherwise fall back to the conservative answer: treat them as + // overlapping. + return true; + } + + for (size_t i = 0; i < prev.buffer_indices.size(); i++) { + auto prev_dtype = prev.dtype; + auto curr_dtype = curr.dtype; + + const auto &prev_indice = prev.buffer_indices[i]; + const auto &curr_indice = curr.buffer_indices[i]; + + if (!ExprDeepEqual()(prev_indice, curr_indice)) { + PrimExpr prev_indice_bytes = + analyzer_.Simplify(prev_indice * prev_dtype.bytes()); + PrimExpr curr_indice_bytes = + analyzer_.Simplify(curr_indice * curr_dtype.bytes()); + + has_same_index = false; + + // If both are const, we can check if they are disjoint + // by checking if the bounds are disjoint + // [1024, 2048], [2048, 3072] are disjoint + // [1024, 2048], [1024, 1024] are not disjoint + auto prev_bound = analyzer_.const_int_bound(prev_indice_bytes); + auto curr_bound = analyzer_.const_int_bound(curr_indice_bytes); + if (prev_bound.defined() && curr_bound.defined()) { + if ((prev_bound->min_value) > (curr_bound->max_value) || + (curr_bound->min_value) > (prev_bound->max_value)) { + range_is_overlap = false; + break; + } + } + + // if we can prove prev_indice < curr_indice or prev_indice > + // curr_indice, then they are not overlap + auto prev_indices_dtype = prev_indice.dtype(); + auto curr_indices_dtype = curr_indice.dtype(); + if (prev_indices_dtype.lanes() != curr_indices_dtype.lanes()) { + // can not support different lanes binary op like <, >, <=, >= + // skip otherwise it will lead to error + continue; + } + + // provably disjoint means no overlap, for example: + // we can prove that tx - 128 < tx + 128, tx in [0, 128] + // However, we should apply tx split because + // tx < tx + 32 when tx in [0, 128] is not disjoint + // because [0, 128] is not disjoint with [32, 160] + // so we should split tx into tx0 and tx1. 
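When two accesses use different indices, the conflict test above first tries to prove them disjoint from the constant bounds of their byte offsets: they cannot conflict if one touched interval ends before the other begins. The interval test itself is simply:

```cpp
#include <cstdint>

struct ByteRange {
  int64_t min_value;  // smallest byte offset the access can touch
  int64_t max_value;  // largest byte offset the access can touch
};

// True when the two touched ranges provably do not overlap.
bool ProvablyDisjoint(const ByteRange& a, const ByteRange& b) {
  return a.min_value > b.max_value || b.min_value > a.max_value;
}
```

When the bounds stay symbolic, the pass additionally renames the thread index in the two accesses (tx1 versus tx2 in the code that follows) so that the disjointness proof quantifies over two independent threads of the same range rather than a single thread compared with itself.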
+ + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + IterVar iv; + } thread_vars[] = { + {"tx1", "tx2", tx_}, + {"ty1", "ty2", ty_}, + {"tz1", "tz2", tz_}, + }; + + for (const auto &info : thread_vars) { + Var prev_var(info.name_prev, info.iv->var.dtype()); + Var curr_var(info.name_curr, info.iv->var.dtype()); + analyzer_.Bind(prev_var, info.iv->dom); + analyzer_.Bind(curr_var, info.iv->dom); + prev_indice_bytes = + Substitute(prev_indice_bytes, {{info.iv->var, prev_var}}); + curr_indice_bytes = + Substitute(curr_indice_bytes, {{info.iv->var, curr_var}}); + } + + bool provably_disjoint = + analyzer_.CanProve(prev_indice_bytes < curr_indice_bytes, + arith::ProofStrength::kSymbolicBound) || + analyzer_.CanProve(prev_indice_bytes > curr_indice_bytes, + arith::ProofStrength::kSymbolicBound); + + if (provably_disjoint) { + range_is_overlap = false; + break; + } + } + + if (!has_same_index) { + break; + } + } + + if (has_same_index && range_is_equal) { + return false; + } + + // If this is a read into a double buffer that was previously + // swapped out, then it doesn't conflict. + if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { + return false; + } + + // If nothing else allows sharing the same buffer, then they are + // in conflict. + // if range_is_overlap is true, then they are in conflict, we should return + // true. if range_is_overlap is false, then they are not in conflict, we + // should return false. + return range_is_overlap; + } + + bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) { + if (lhs.touched.size() != 1 || rhs.touched.size() != 1) { + return false; + } + PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min()); + PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max()); + PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min()); + PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max()); + + if (analyzer_.CanProve(lhs_max < rhs_min, + arith::ProofStrength::kSymbolicBound)) { + return true; + } + if (analyzer_.CanProve(rhs_max < lhs_min, + arith::ProofStrength::kSymbolicBound)) { + return true; + } + return false; + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + tx_ = iv; + } else if (iv->thread_tag == "threadIdx.y") { + ty_ = iv; + } else if (iv->thread_tag == "threadIdx.z") { + tz_ = iv; + } + } + TileLangStorageAccessVisitor::VisitStmt_(op); + } + + void insert_syncs(const Object *obj) { + if (syncs_inserted_.count(obj)) + return; + syncs_inserted_.insert(obj); + } + +private: + // Member variables + IterVar tx_ = + IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); + IterVar ty_ = + IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); + IterVar tz_ = + IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); + // synchronization scope + StorageScope sync_scope_; +}; + +// There are cases where necessary syncthreads is not inserted by +// ThreadSyncInserter. For example, syncthreads is needed after async_wait_queue +// in the second loop below, but since ThreadSyncInserter is not aware of the +// asynchronous semantics, it cannot tell that the syncthreads is needed there. +// +// // Pipeline prologue +// for i in range(125): +// async_commit_queue(0): +// async_scope: +// shared[(i + 3) % 4] = ... +// ... 
+// +// // Pipeline Epilogue +// for i in range(3): +// async_wait_queue(0, 2 - i): +// local[...] = shared[(i + 125) % 4] + +// This class adds syncthreads after all async_wait_queue. That includes +// syncthreads that can be inserted by ThreadSyncInserter as well, but +// ThreadSyncInserter will not insert duplicate syncthreads if it finds an +// existing one at the synchronization point. +class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator { +public: + explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) + : sync_scope_(std::move(sync_scope)) {} + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) { + auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string())})); + auto inner = op->body.as(); + ICHECK(inner && + inner->attr_key == tvm::tir::attr::async_wait_inflight_count); + auto zero = make_zero(DataType::Int(32)); + auto new_body = SeqStmt({sync, inner->body}); + return AttrStmt(zero, tvm::tir::attr::async_wait_queue_scope, op->value, + AttrStmt(zero, tvm::tir::attr::async_wait_inflight_count, + inner->value, new_body)); + } + return StmtExprMutator::VisitStmt_(op); + } + +private: + StorageScope sync_scope_; +}; + +class ThreadSyncInserter : public StmtExprMutator { +public: + ThreadSyncInserter(StorageScope sync_scope, + const std::unordered_set &syncs) + : sync_scope_(std::move(sync_scope)), syncs_(syncs) {} + + Stmt VisitStmt(const Stmt &stmt) final { + if (syncs_.empty()) + return stmt; + if (syncs_.count(stmt.get())) { + Stmt barrier; + if (sync_scope_.rank == StorageRank::kGlobal) { + barrier = MakeGlobalBarrier(); + } else { + barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string())})); + } + // Mutate after query, to avoid stmt change. + auto ret = StmtExprMutator::VisitStmt(stmt); + ret = SeqStmt({barrier, ret}); + return ret; + } else { + return StmtExprMutator::VisitStmt(stmt); + } + } + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].read_count; + } + return StmtExprMutator::VisitExpr_(op); + } + Stmt VisitStmt_(const BufferStoreNode *op) final { + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].write_count; + } + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::thread_extent) { + bool temp = true; + std::swap(temp, in_thread_env_); + thread_extents_.push_back(op); + Stmt ret = StmtExprMutator::VisitStmt_(op); + thread_extents_.pop_back(); + std::swap(temp, in_thread_env_); + // first thread scope. 
+ if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) { + ret = InitGlobalBarrier(ret.as()); + num_blocks_ = PrimExpr(); + is_lead_ = PrimExpr(); + } + return ret; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_EQ(op->args.size(), 5U); + Var buffer_var(Downcast(op->args[1])); + const IntImmNode *flag = op->args[4].as(); + if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].read_count; + } + if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].write_count; + } + return expr; + } else if (op->op.same_as(builtin::address_of())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_EQ(op->args.size(), 1U) + << "address_of should only have one argument (Buffer)"; + + if (auto load = op->args[0].as()) { + Var buffer_var(Downcast(load->buffer->data)); + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].read_count; + } + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].write_count; + } + return expr; + } else { + return StmtExprMutator::VisitExpr_(op); + } + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + +private: + // RW statistics about data + struct Entry { + int read_count{0}; + int write_count{0}; + }; + + // Get current storage scope. + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + } + + // private functions. + Stmt InitGlobalBarrier(const AttrStmtNode *op) { + ICHECK(op != nullptr); + Array pargs = { + StringImm(runtime::symbol::tvm_prepare_global_barrier)}; + Stmt prep = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); + Stmt body = op->body; + for (const auto &kv : rw_stats_) { + const auto &e = kv.second; + if (e.read_count != 0 && e.write_count != 0) { + body = AttrStmt(kv.first, tvm::tir::attr::volatile_scope, 1, body); + } + } + rw_stats_.clear(); + Stmt kinit = Evaluate( + Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {})); + body = SeqStmt({kinit, body}); + body = AttrStmt(op->node, op->attr_key, op->value, body); + return SeqStmt({prep, body}); + } + Stmt MakeGlobalBarrier() { + ICHECK(sync_scope_.rank == StorageRank::kGlobal); + if (!num_blocks_.defined()) { + ICHECK(!is_lead_.defined()); + num_work_dim_ = thread_extents_.size(); + for (const AttrStmtNode *attr : thread_extents_) { + IterVar iv = Downcast(attr->node); + runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); + if (s.rank == 0) { + num_blocks_ = + (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); + } else if (s.rank == 1) { + PrimExpr cond = iv->var == make_zero(iv->var.dtype()); + is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; + } + } + } else { + ICHECK_EQ(num_work_dim_, thread_extents_.size()); + } + return Evaluate( + Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); + } + // data structure. 
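MakeGlobalBarrier above caches two quantities from the enclosing launch attributes: the total number of thread blocks (the product of the grid-rank extents) and a leader predicate that is true for exactly one thread per block. In scalar form, with illustrative types standing in for the PrimExprs the pass actually builds:

```cpp
#include <cstdint>
#include <vector>

struct LaunchDim {
  int rank;        // 0 = block (grid) dimension, 1 = thread dimension
  int64_t extent;  // extent of the dimension
  int64_t index;   // runtime index of this thread along the dimension
};

struct BarrierArgs {
  int64_t num_blocks;  // how many blocks participate in the global barrier
  bool is_lead;        // true for exactly one thread per block
};

BarrierArgs ComputeBarrierArgs(const std::vector<LaunchDim>& dims) {
  BarrierArgs out{1, true};
  for (const LaunchDim& d : dims) {
    if (d.rank == 0) out.num_blocks *= d.extent;        // grid dimensions multiply
    else out.is_lead = out.is_lead && (d.index == 0);   // all thread indices zero
  }
  return out;
}
```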
+ StorageScope sync_scope_; + const std::unordered_set &syncs_; + + // The read write statistics of storage + std::unordered_map rw_stats_; + // The statistics for global barrier + bool in_thread_env_{false}; + // memorized results + std::vector thread_extents_; + size_t num_work_dim_{0}; + PrimExpr num_blocks_; + PrimExpr is_lead_; +}; + +class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { +public: + static Stmt Rewrite(Stmt stmt) { + arith::Analyzer analyzer; + ThreadPartialSyncRewriter rewriter(&analyzer); + return rewriter(std::move(stmt)); + } + +private: + explicit ThreadPartialSyncRewriter(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {} + + Stmt VisitStmt_(const EvaluateNode *op) final { + const CallNode *call = nullptr; + if (op->value->IsInstance()) { + call = op->value.as(); + if (call->op.same_as(builtin::tvm_storage_sync())) { + const auto &args = call->args; + ICHECK(!args.empty()); + const auto *scope_node = args[0].as(); + ICHECK(scope_node != nullptr); + const std::string &scope = scope_node->value; + + if (args.size() != 1 || (scope != "shared" && scope != "shared.dyn")) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + return ProcessSharedSync(call, scope); + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt ProcessSharedSync(const CallNode *op, const std::string &scope) { + // Get thread bounds + auto bound_tx = analyzer_->const_int_bound(tx_); + auto bound_ty = analyzer_->const_int_bound(ty_); + auto bound_tz = analyzer_->const_int_bound(tz_); + + // Check if all threads are participating (full extent) + if (IsFullThreadExtent(tx_, bound_tx) && + IsFullThreadExtent(ty_, bound_ty) && + IsFullThreadExtent(tz_, bound_tz)) { + return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op)); + } + + // Calculate thread extents + auto extent_tx = CalculateThreadExtent(tx_, bound_tx); + auto extent_ty = CalculateThreadExtent(ty_, bound_ty); + auto extent_tz = CalculateThreadExtent(tz_, bound_tz); + + // Create or get barrier info + ThreadBoundKey key{bound_tx->min_value, bound_tx->max_value, + bound_ty->min_value, bound_ty->max_value, + bound_tz->min_value, bound_tz->max_value}; + + auto [barrier_id, thread_count] = + GetOrCreateBarrier(key, extent_tx, extent_ty, extent_tz); + if (thread_count % 32 != 0) { + // TODO(lei): This is a workaround for the case where the thread count is + // not a multiple of 32. we should enhance the pass to analysis index + // instead of buffer expression etc. 
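+      // Illustrative note: when the participating thread count is a multiple
+      // of 32, the plain tvm_storage_sync("shared"/"shared.dyn") is instead
+      // rewritten below into a named-barrier form, roughly
+      //   tvm_storage_sync("shared.dyn", barrier_id, thread_count)
+      // so that only the participating subset of threads synchronizes.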
+ return Stmt(); + } + + // Create new sync call with barrier info + Array new_args = {StringImm(scope), + IntImm(DataType::Int(32), barrier_id), + IntImm(DataType::Int(32), thread_count)}; + return Evaluate(Call(op->dtype, op->op, new_args)); + } + + std::pair GetOrCreateBarrier(const ThreadBoundKey &key, + size_t extent_tx, + size_t extent_ty, + size_t extent_tz) { + if (barrier_id_map_.count(key)) { + return {barrier_id_map_[key], thread_count_map_[key]}; + } + + size_t barrier_id = + barrier_id_map_.size() + + static_cast(ReservedNamedBarriers::kFirstUsedBarrier); + size_t thread_count = extent_tx * extent_ty * extent_tz; + + barrier_id_map_[key] = barrier_id; + thread_count_map_[key] = thread_count; + + return {barrier_id, thread_count}; + } + + size_t CalculateThreadExtent(const IterVar &iv, + const arith::ConstIntBound &bound) { + if (!analyzer_->const_int_bound.IsBound(iv->var)) { + return 1; + } + return bound->max_value - bound->min_value + 1; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + tx_ = iv; + } else if (iv->thread_tag == "threadIdx.y") { + ty_ = iv; + } else if (iv->thread_tag == "threadIdx.z") { + tz_ = iv; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + bool IsFullThreadExtent(const IterVar &iv, + const arith::ConstIntBound &bound) { + if (!analyzer_->const_int_bound.IsBound(iv->var)) { + return true; + } + + if (!iv->dom.defined()) { + return true; + } + + const auto *min_node = iv->dom->min.as(); + const auto *extent_node = iv->dom->extent.as(); + + int64_t min = min_node->value; + int64_t extent = extent_node->value; + int64_t max = min + extent - 1; + + return min == bound->min_value && max == bound->max_value; + } + + // Member variables + IterVar tx_ = + IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); + IterVar ty_ = + IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); + IterVar tz_ = + IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); + std::unordered_map barrier_id_map_; + std::unordered_map thread_count_map_; +}; + +PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) { + StorageScope sync_scope = StorageScope::Create(storage_scope); + auto *n = func.CopyOnWrite(); + auto stmt = n->body; + if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) { + stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt); + } + TileLangThreadSyncPlanner planner(sync_scope); + for (const auto &[_, buffer] : func->buffer_map) { + planner.SetBufferDataToBuffer(buffer->data, buffer); + } + planner(stmt); + + stmt = + ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); + n->body = ThreadPartialSyncRewriter::Rewrite(std::move(stmt)); + return func; +} + +using namespace tir::transform; + +namespace transform { + +tvm::transform::Pass ThreadSync(const String &storage_scope) { + auto pass_func = [storage_scope](PrimFunc f, const IRModule &m, + const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + // Check if thread storage sync is disabled + bool disable_syncthreads = + ctx->GetConfig(kDisableThreadStorageSync, Bool(false)).value()->value; + if (disable_syncthreads) { + return f; + } + return tl::TileLangThreadSync(std::move(f), storage_scope); + ; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + 
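+  // The pass is reached through the FFI registry by the global name
+  // registered below; the string argument selects the storage scope to
+  // synchronize (e.g. "shared" or "global").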
refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync); +} + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/vectorize_loop.cc b/tilelang/original/src/transform/vectorize_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7b31e1d773068f51dbc206a9b5b19cda6bb1952 --- /dev/null +++ b/tilelang/original/src/transform/vectorize_loop.cc @@ -0,0 +1,887 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file vectorize_loop.cc + */ +// Loop vectorizer as in Halide pipeline. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "arith/scalable_expression.h" +#include "tir/analysis/check_contains.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace ffi; + +/*! + * \brief Perform data type legalization on the given BufferLoadNode pointer. + * Equal to BufferLoadNode::LegalizeDType, but operates on a pointer. + * \param n A pointer to a writable BufferLoadNode. 
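+ * \note Worked example (illustrative): for a buffer of dtype float16 and a
+ *       last index of dtype int32x8, the legalized load dtype is float16x8.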
+ */ +static void LegalizeBufferLoadDType(BufferLoadNode *n) { + // Check that all indices except the last one have a scalar dtype + for (int i = 0; i < static_cast(n->indices.size()) - 1; i++) { + ICHECK(n->indices[i].dtype().is_scalar()) + << "Only the last index of a buffer access may be a vector type."; + } + + // If there are no indices, set the dtype to the buffer's dtype + if (n->indices.empty()) { + n->dtype = n->buffer->dtype; + } else { + auto index_dtype = n->indices.back().dtype(); + bool is_buffer_dtype_scalable = n->buffer->dtype.is_scalable_vector(); + bool is_index_scalable = index_dtype.is_scalable_vector(); + + // Do not allow both index dtype and buffer dtype to be scalable vectors + ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) + << "Index dtype and buffer dtype cannot both be scalable."; + + if (is_index_scalable) { + // Index is a scalable vector, while the buffer is not + n->dtype = n->buffer->dtype.with_scalable_vscale_factor( + index_dtype.vscale_factor() * n->buffer->dtype.lanes()); + } else if (is_buffer_dtype_scalable) { + // The buffer is a scalable vector, while the index is not + n->dtype = n->buffer->dtype.with_scalable_vscale_factor( + n->buffer->dtype.vscale_factor() * index_dtype.lanes()); + } else { + // Neither side is a scalable vector, multiply lanes + n->dtype = n->buffer->dtype.with_lanes(index_dtype.lanes() * + n->buffer->dtype.lanes()); + } + } +} + +inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { + if (is_scalable) { + return Mul(Call(DataType::Int(32), builtin::vscale(), {}), + lanes_or_vscale_factor); + } else { + return lanes_or_vscale_factor; + } +} + +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { + // Check if e is already in the expected form + if (e.dtype().get_lanes_or_vscale_factor() == lanes && + e.dtype().is_scalable_vector() == is_scalable) + return e; + + if (const BroadcastNode *op = e.as()) { + ICHECK(op->dtype.is_scalable_vector() == is_scalable) + << "Can't broadcast between scalable and fixed length vectors."; + int e_lanes = op->dtype.get_lanes_or_vscale_factor(); + + if (lanes % e_lanes == 0) { + return Broadcast(op->value, CreateNewLanes(is_scalable, lanes)); + } + } + + ICHECK(e.dtype().is_scalar()) + << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor() + << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes; + + return Broadcast(e, CreateNewLanes(is_scalable, lanes)); +} + +// Rewrite vectorized allocation access +// This is necessary for making each vector component containing its own +// workspace. Originates from Halide's loop vectorizer +// +// s[i] = s[i * lanes + var] +// +// The same principle applies when using one thread to simulate multiple +// context. +// +class TLVecAllocAccess : public StmtExprMutator { +public: + TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes) + : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {} + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(store); + } + +private: + template Node UpdateBufferAccess(Node node) { + // Only update the buffer that's being replaced. + if (node->buffer->data.get() != buf_) { + return node; + } + + // Find/make a Buffer object with the correct updated shape. 
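+    // For example (illustrative): a flat buffer of shape [N] accessed inside
+    // a loop vectorized by var_lanes_ has its last dimension extended to
+    // [N * var_lanes_], so that each vector lane can own its own slot.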
+ Buffer buf; + auto it = buffer_map_.find(node->buffer.get()); + if (it != buffer_map_.end()) { + buf = it->second; + } else { + // Extend the least significant dimension by a factor of + // var_lanes_. Typically, this will be a 1-d index into a flat + // memory space. + Array shape = node->buffer->shape; + shape.Set(shape.size() - 1, + analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); + + // TODO(Lunderberg): Move this pass to be prior to + // StorageFlatten/FlattenBuffer, implement by appending a + // dimension to the buffer. Since it is currently after the + // flattening, the strides are not technically necessary, but + // are updated for consistency. + + // Update strides if defined. + Array strides; + for (size_t i = 0; i < strides.size(); i++) { + PrimExpr stride = strides[i]; + if (i != strides.size() - 1) { + stride *= var_lanes_; + } + strides.push_back(analyzer_.Simplify(stride)); + } + + // Copy everything into the new buffer. + buf = node->buffer; + auto buf_writer = buf.CopyOnWrite(); + buf_writer->shape = shape; + buf_writer->strides = strides; + buffer_map_[buf.get()] = buf; + } + + return node; + } + + // buffer var + const VarNode *buf_; + // Updated buffer objects. + std::unordered_map buffer_map_; + // variable to be replaced + Var var_; + // the lanes. + PrimExpr var_lanes_; + // Analyzer for simplifications + arith::Analyzer analyzer_; +}; + +// We use ExprFunctor directly instead of StmtExprMutator +// This is because the transformation can change the dtype of the Expr +// The existing ExprMutator transformation rules may not be well defined. +class TLVectorizer : public StmtMutator, + public ExprFunctor { +public: + using ExprFunctor::VisitExpr; + using StmtMutator::operator(); + + // Convenience entry to vectorize a loop body without exposing + // the mutator invocation pattern at call sites. 
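+  // Example (illustrative sketch): for
+  //   for (i [vectorized], 0, 4) { C[i] = A[i] + 1.f; }
+  // Vectorize(i, 4, body) rewrites the body roughly into
+  //   C[ramp(0, 1, 4)] = A[ramp(0, 1, 4)] + broadcast(1.f, 4)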
+ static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) { + TLVectorizer vec{var, var_lanes}; + auto vec_stmt = vec(std::move(body)); + return vec_stmt; + } + + TLVectorizer(const Var &var, const PrimExpr &var_lanes) + : var_(var), var_lanes_(var_lanes) { + ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes); + } + + Stmt VisitStmt(const Stmt &stmt) final { + ICHECK(!need_scalarize_); + Stmt ret = StmtMutator::VisitStmt(stmt); + if (need_scalarize_) { + auto scalarized_stmt = Scalarize(stmt); + need_scalarize_ = false; + return scalarized_stmt; + } else { + return ret; + } + } + + PrimExpr VisitExpr(const PrimExpr &e) final { + return ExprFunctor::VisitExpr(e); + } + + PrimExpr VisitExpr_(const AddNode *op) final { + return AddSubVec( + op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); }); + } + + PrimExpr VisitExpr_(const SubNode *op) final { + return AddSubVec( + op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); }); + } + + PrimExpr VisitExpr_(const MulNode *op) final { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); + bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); + if (is_vec_a && is_vec_b) { + // Let's not multiply scalable and fixed length vectors + ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector()) + << "Fixed length and scalable vectors can't be mixed in " + "multiplication."; + } + if (is_vec_a || is_vec_b) { + const RampNode *b_ramp = b.as(); + const RampNode *a_ramp = a.as(); + if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) { + PrimExpr lanes = a_ramp->lanes; + return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); + } + if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) { + PrimExpr lanes = b_ramp->lanes; + return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); + } + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int max_lanes = std::max(a_lanes, b_lanes); + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return Mul(BroadcastTo(a, max_lanes, is_scalable), + BroadcastTo(b, max_lanes, is_scalable)); + } + } + return BinaryVec(op); + } + PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec
(op); } + PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode *op) final { + return BinaryVec(op); + } + PrimExpr VisitExpr_(const FloorModNode *op) final { + return BinaryVec(op); + } + PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec(op); } + + PrimExpr VisitExpr_(const NotNode *op) final { + PrimExpr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return tvm::ffi::GetRef(op); + } else { + return !(a); + } + } + + PrimExpr VisitExpr_(const RampNode *op) final { + PrimExpr base = this->VisitExpr(op->base); + PrimExpr stride = this->VisitExpr(op->stride); + ICHECK(!base.dtype().is_scalable_vector()) + << "Creating scalable vectors from existing vectors is not supported."; + ICHECK(!stride.dtype().is_scalable_vector()) + << "Ramp stride with scalable dtype is not supported"; + if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) { + ICHECK(op->lanes->IsInstance()) + << "Vectorizing over existing scalable vectors is not supported."; + const RampNode *base_ramp = base.as(); + int op_lanes = static_cast(Downcast(op->lanes)->value); + int base_ramp_lanes = + static_cast(Downcast(base_ramp->lanes)->value); + if (analyzer_.CanProve(base_ramp->stride == + stride * + make_const(stride.dtype(), base_ramp_lanes))) { + return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes); + } + } + int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); + Array elems; + for (int i = 0; i < lanes; ++i) { + elems.push_back(Ramp(Shuffle::ExtractElement(base, i), + Shuffle::ExtractElement(stride, i), op->lanes)); + } + return Shuffle::Concat(elems); + } + + PrimExpr VisitExpr_(const BroadcastNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + if (value.same_as(op->value)) { + return tvm::ffi::GetRef(op); + } else { + return Broadcast(op->value, op->lanes); + } + } + + PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr t = this->VisitExpr(op->true_value); + PrimExpr f = this->VisitExpr(op->false_value); + if (cond.same_as(op->condition) && t.same_as(op->true_value) && + f.same_as(op->false_value)) { + return tvm::ffi::GetRef(op); + } else { + int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes); + bool is_scalable = cond.dtype().is_scalable_vector() || + t.dtype().is_scalable_vector() || + f.dtype().is_scalable_vector(); + return Select(BroadcastTo(cond, lanes, is_scalable), + BroadcastTo(t, lanes, 
is_scalable), + BroadcastTo(f, lanes, is_scalable)); + } + } + + PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return tvm::ffi::GetRef(op); + } else { + if (value.dtype().is_scalable_vector()) { + return Cast(op->dtype.with_scalable_vscale_factor( + value.dtype().vscale_factor()), + value); + } else { + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + } + } + } + + PrimExpr VisitExpr_(const FloatImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const IntImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + PrimExpr VisitExpr_(const StringImmNode *op) final { + return tvm::ffi::GetRef(op); + } + + // Variable + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = tvm::ffi::GetRef(op); + + if (var.same_as(var_)) { + return ramp_; + } + auto it = let_var_map_.find(var); + if (it != let_var_map_.end()) { + return it->second; + } else { + return std::move(var); + } + } + // IfThenElse expr + PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr cond = this->VisitExpr(op->args[0]); + if (cond.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + PrimExpr t = this->VisitExpr(op->args[1]); + PrimExpr f = this->VisitExpr(op->args[2]); + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && + f.same_as(op->args[2])) { + return tvm::ffi::GetRef(op); + } else { + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(t_lanes, f_lanes); + bool is_scalable = + t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector(); + t = BroadcastTo(t, lanes, is_scalable); + f = BroadcastTo(f, lanes, is_scalable); + if (is_scalable) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, + {cond, t, f}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + } + } + } + // Reinterpret expr + PrimExpr MutateReinterpretExpr_(const CallNode *op) { + ICHECK(op->op.same_as(builtin::reinterpret())); + PrimExpr value = this->VisitExpr(op->args[0]); + if (value.same_as(op->args[0])) { + return tvm::ffi::GetRef(op); + } else { + int lanes = value.dtype().get_lanes_or_vscale_factor(); + if (value.dtype().is_scalable_vector()) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, + {value}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {value}); + } + } + } + // Call + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::if_then_else())) { + return MutateIfThenElseExpr_(op); + } else if (op->op.same_as(builtin::texture2d_load())) { + int lane = 0; + Array fcd = MutateArray({op->args.back()}, &lane); + auto new_args = op->args; + new_args.pop_back(); + new_args.push_back(fcd[0]); + return Call(op->dtype.with_lanes(4), op->op, new_args); + } else if (op->op.same_as(builtin::texture2d_store())) { + int lane = 0; + // Vectorize the value to store + Array value{op->args.back()}; + Array mutated_value = MutateArray(value, &lane); + Array new_args{op->args[0], op->args[1], op->args[2], + mutated_value[0]}; + return Call(op->dtype.with_lanes(lane), op->op, new_args); + } else if (op->op.same_as(builtin::reinterpret())) { + return MutateReinterpretExpr_(op); + } + auto optional_op = op->op.as(); + bool vectorizable = optional_op && + op_vectorizable_.get(optional_op.value(), false) && + !op->dtype.is_scalable_vector(); + if (!vectorizable) { + // 
Cannot vectorize this op + Array new_args; + for (auto arg : op->args) { + auto new_arg = this->VisitExpr(arg); + if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + new_args.push_back(new_arg); + } + if (op->args.same_as(new_args)) { + return tvm::ffi::GetRef(op); + } else { + return Call(op->dtype, op->op, new_args); + } + } else { + int lane = 0; + Array new_args = MutateArray(op->args, &lane); + // normal code path. + if (op->args.same_as(new_args)) { + return tvm::ffi::GetRef(op); + } else { + return Call(op->dtype.with_lanes(lane), op->op, new_args); + } + } + } + // BufferLoad + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto load = tvm::ffi::GetRef(op); + + auto fmutate = [this](const PrimExpr &index) { + return this->VisitExpr(index); + }; + Array indices = op->indices.Map(fmutate); + + if (!indices.same_as(op->indices)) { + BufferLoadNode *writer = load.CopyOnWrite(); + writer->indices = indices; + LegalizeBufferLoadDType(writer); + } + + return std::move(load); + } + // Let + PrimExpr VisitExpr_(const LetNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_var_map_.find(op->var); + if (it != let_var_map_.end()) { + ICHECK(deep_equal_(it->second, value)) + << "Let cannot bind the same var to two different values"; + } + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { + Var new_var(op->var->name_hint, value.dtype()); + let_var_map_[op->var] = new_var; + // Record mapping from the new var to its bound value + let_value_binding_[new_var] = value; + return Let(new_var, value, this->VisitExpr(op->body)); + } else { + let_var_map_[op->var] = op->var; + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return Let(op->var, value, body); + } + } + } + // BufferStore + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = tvm::ffi::GetRef(op); + + auto fmutate = [this](const PrimExpr &index) { + return this->VisitExpr(index); + }; + Array indices = op->indices.Map(fmutate); + + PrimExpr value = this->VisitExpr(op->value); + + if (!indices.same_as(op->indices) || !value.same_as(op->value)) { + ICHECK(!op->buffer->dtype.is_scalable_vector()) + << "Vectorizing over scalable buffer elements is not supported in " + "vectorizer."; + // How many lanes of indexing are present in the index and + // buffer element type, excluding the last index. + int other_index_lanes = op->buffer->dtype.lanes(); + for (size_t i = 0; i < indices.size() - 1; i++) { + other_index_lanes *= indices[i].dtype().lanes(); + // Only allow the last index to be scalable + ICHECK(!indices[i].dtype().is_scalable_vector()) + << "Only the last index can be scalable."; + } + + // The total number of lanes of indexing, including the last index. + auto last_index_dtype = indices[indices.size() - 1].dtype(); + int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor(); + int index_lanes = other_index_lanes * lanes_in_last_index; + + // The total number of lanes in this store operation. 
Either + // the index or the value will be broadcast out to this number + // of lanes, depending on which has more lanes. + int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); + bool is_last_index_scalable = last_index_dtype.is_scalable_vector(); + int total_lanes = std::max(index_lanes, value_dtype_lanes); + + ICHECK_EQ(total_lanes % other_index_lanes, 0) + << "When storing to buffer " << op->buffer->name + << ", cannot produce " << total_lanes + << " lanes of storage location by changing the last index."; + int last_index_lanes = total_lanes / other_index_lanes; + + // Broadcast the last index such that the total number of index + // lanes matches the desired number. + indices.Set(indices.size() - 1, + BroadcastTo(indices[indices.size() - 1], last_index_lanes, + is_last_index_scalable)); + + auto writer = store.CopyOnWrite(); + writer->indices = indices; + writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable); + } + + return std::move(store); + } + // For + Stmt VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kVectorized) { + LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; + } + ICHECK(is_zero(op->min)); + ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); + PrimExpr extent = this->VisitExpr(op->extent); + if (extent.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(tvm::ffi::GetRef(op)); + } + Stmt body = this->VisitStmt(op->body); + if (extent.same_as(op->extent) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return For(op->loop_var, op->min, extent, op->kind, body, + op->thread_binding, op->annotations); + } + } + // IfThenElse + Stmt VisitStmt_(const IfThenElseNode *op) final { + ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); + PrimExpr condition = this->VisitExpr(op->condition); + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(tvm::ffi::GetRef(op)); + } + Stmt then_case = this->VisitStmt(op->then_case); + Optional else_case = std::nullopt; + if (op->else_case) { + else_case = this->VisitStmt(op->else_case.value()); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return tvm::ffi::GetRef(op); + } else { + return IfThenElse(condition, then_case, else_case); + } + } + // While + Stmt VisitStmt_(const WhileNode *op) final { + LOG(FATAL) << "A while loop inside a vectorized loop not supported."; + } + // LetStmt + Stmt VisitStmt_(const LetStmtNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + ICHECK(!let_var_map_.count(op->var)) + << "SSA violation, a single var is binded twice"; + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { + Var new_var(op->var->name_hint, value.dtype()); + let_var_map_[op->var] = new_var; + // Record mapping from the new var to its bound value + let_value_binding_[op->var] = op->value; + let_value_binding_[new_var] = value; + + return LetStmt(new_var, value, this->VisitStmt(op->body)); + } else { + let_var_map_[op->var] = op->var; + let_value_binding_[op->var] = value; + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return tvm::ffi::GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + } + + // Allocate + Stmt VisitStmt_(const AllocateNode *op) final { + // Mutate the condition + PrimExpr condition = this->VisitExpr(op->condition); + if 
(condition.dtype().is_scalable_or_fixed_length_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc of " + << op->buffer_var->name_hint; + return Scalarize(tvm::ffi::GetRef(op)); + } + + return StmtMutator::VisitStmt_(op); + } + + // scalarize the statement + Stmt Scalarize(Stmt stmt) { + Var idx(var_->name_hint + "_s", var_->dtype); + // Find all Vars in stmt that are keys in let_value_binding_ + std::unordered_set used_let_bound_vars; + PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) { + if (const auto *v = node.as()) { + Var var = GetRef(v); + if (let_value_binding_.count(var)) { + used_let_bound_vars.insert(var); + } + } + }); + stmt = Substitute(stmt, {{var_, idx}}); + + if (!used_let_bound_vars.empty()) { + for (const auto &v : used_let_bound_vars) { + // Bind the existing var v to its value around the stmt scope + auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}}); + stmt = LetStmt(v, new_value, stmt); + } + } + + return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); + } + +private: + // analyzer + arith::Analyzer analyzer_; + // deep equal + ExprDeepEqual deep_equal_; + // variable to be replaced + Var var_; + // the lanes. + PrimExpr var_lanes_; + // ramp representing the var. + PrimExpr ramp_; + // flag to mark requirement of scalarization. + bool need_scalarize_{false}; + // Let var mapping + std::unordered_map let_var_map_; + // Let value binding: map new_var -> value + std::unordered_map + let_value_binding_; + // vectorizable property + OpAttrMap op_vectorizable_ = + Op::GetAttrMap("TVectorizable"); + + // mutate array, with given lane requirement + // when finished, p_lane updates the lane requirement. + Array MutateArray(Array arr, int *p_lanes) { + if (arr.empty()) + return arr; + int &lanes = *p_lanes; + bool changed = false; + std::vector new_arr(arr.size()); + for (size_t i = 0; i < arr.size(); i++) { + PrimExpr old_elem = arr[i]; + PrimExpr new_elem = this->VisitExpr(old_elem); + if (!new_elem.same_as(old_elem)) + changed = true; + new_arr[i] = new_elem; + lanes = std::max(lanes, new_elem.dtype().lanes()); + } + + for (size_t i = 0; i < arr.size(); ++i) { + if (new_arr[i].dtype().lanes() != lanes) { + new_arr[i] = BroadcastTo(new_arr[i], lanes, false); + changed = true; + } + } + if (!changed) + return arr; + return Array(new_arr); + } + template PrimExpr BinaryVec(const T *op) { + static_assert(std::is_same::value, + "constraint"); + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return TOp(BroadcastTo(a, lanes, is_scalable), + BroadcastTo(b, lanes, is_scalable)); + } + } + template + PrimExpr AddSubVec(const T *op, FCompute fcompute) { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return tvm::ffi::GetRef(op); + } else { + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + if (lanes != 1) { + const RampNode *b_ramp = b.as(); + const RampNode *a_ramp = a.as(); + if (a.dtype().is_scalar() && b_ramp) { + return Ramp( + fcompute(a, b_ramp->base), + 
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), + b_ramp->lanes); + } + if (b.dtype().is_scalar() && a_ramp) { + return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + } + } + bool is_scalable = + a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return fcompute(BroadcastTo(a, lanes, is_scalable), + BroadcastTo(b, lanes, is_scalable)); + } + } +}; + +inline bool TargetHasSVE() { + return Target::Current()->GetFeature("has_sve").value_or(false); +} + +class LoopVectorizer : public StmtMutator { +public: + Stmt VisitStmt_(const ForNode *op) final { + if (op->kind == ForKind::kVectorized) { + auto *extent_as_int = op->extent.as(); + + if (!extent_as_int || extent_as_int->value < 1) { + bool is_scalable_expr = + CheckContains::ExprContains(op->extent, arith::IsVScaleCall); + ICHECK(is_scalable_expr && TargetHasSVE()) + << "Failed to vectorize loop with extent " << op->extent + << " for target " << Target::Current(); + } + ICHECK(is_zero(op->min)); + return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body); + } else { + return StmtMutator::VisitStmt_(op); + } + } +}; + +class VectorizeSkipper : public StmtMutator { +public: + Stmt VisitStmt_(const ForNode *op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + op = stmt.as(); + if (op->kind == ForKind::kVectorized) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body); + } else { + return stmt; + } + } +}; + +Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } + +tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + if (enable_vectorize) { + n->body = tvm::tl::LoopVectorizer()(std::move(n->body)); + } else { + n->body = tvm::tl::VectorizeSkipper()(std::move(n->body)); + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/warp_specialized_rewriter.cc b/tilelang/original/src/transform/warp_specialized_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e891d8551bd2d52e2f0ea7233256bf8cecf0d17 --- /dev/null +++ b/tilelang/original/src/transform/warp_specialized_rewriter.cc @@ -0,0 +1,1325 @@ +/*! 
+ * \file warp_specialized_rewriter.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include "warp_specialized_rewriter.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace runtime; +using arith::IRVisitorWithAnalyzer; + +struct LoopInfo { + Var loop_var; + PrimExpr extent; + PrimExpr min; +}; + +enum class Role : uint8_t { kConsumer, kProducer, kBoth }; + +class ProducerBufferDetector : public StmtExprVisitor { +public: + ProducerBufferDetector( + std::unordered_set cur_producer_buffers) + : cur_producer_buffers_(std::move(cur_producer_buffers)) {} + + void clear() { has_producer_buffer_ = false; } + + void VisitExpr_(const CallNode *call) final { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + has_producer_buffer_ = true; + } + StmtExprVisitor::VisitExpr_(call); + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (cur_producer_buffers_.count(op->buffer.get())) { + has_producer_buffer_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_producer_buffer_ = false; + std::unordered_set cur_producer_buffers_; +}; + +class ProducerUsedBufferFinder : public StmtExprVisitor { +public: + auto FindProducerusedBuffer(const Stmt &stmt) { + producer_buffers_.clear(); + let_var_to_expr_.clear(); + std::unordered_set last_producer_buffers_; + for (;;) { + VisitStmt(stmt); + if (producer_buffers_ == last_producer_buffers_) { + break; + } + last_producer_buffers_ = producer_buffers_; + } + return producer_buffers_; + } + + void InsertBuffer(const PrimExpr &expr) { + // Find the buffer that is used in the condition + VarUseDefAnalyzer usage(Array{}); + usage(expr); + for (const auto &buffer : usage.buffer_use_count_) { + producer_buffers_.insert(buffer.first); + } + // Also collect buffers through let bindings + CollectBuffersFromExpr(expr); + } + + // Collect buffers from expression, following let bindings + void CollectBuffersFromExpr(const PrimExpr &expr) { + PostOrderVisit(expr, [this](const ObjectRef &node) { + if (auto bl = node.as()) { + producer_buffers_.insert(bl->buffer.get()); + } else if (auto var_node = node.as()) { + auto var = tvm::ffi::GetRef(var_node); + auto it = let_var_to_expr_.find(var.get()); + if (it != let_var_to_expr_.end()) { + CollectBuffersFromExpr(it->second); + } + } + }); + } + + void VisitStmt_(const LetStmtNode *op) final { + let_var_to_expr_[op->var.get()] = op->value; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const IfThenElseNode *op) final { + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + producer_buffer_detector(op->then_case); + if (op->else_case.defined()) { + producer_buffer_detector(op->else_case.value()); + } + if (producer_buffer_detector.has_producer_buffer_) { + InsertBuffer(op->condition); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const ForNode *op) final { + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + producer_buffer_detector(op->body); + if (producer_buffer_detector.has_producer_buffer_) { + InsertBuffer(op->min); + InsertBuffer(op->extent); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + if (producer_buffers_.count(op->buffer.get())) { + InsertBuffer(op->value); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + for (auto arg : op->args) { + // Collect buffers from args, including through let bindings + 
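+        // (Illustrative: if a tma_load argument is computed from a value
+        //  stored to another buffer, that buffer becomes a producer buffer on
+        //  a later iteration; FindProducerusedBuffer runs to a fixed point
+        //  until the set stops growing.)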
CollectBuffersFromExpr(arg); + } + } + } + +private: + std::unordered_set producer_buffers_; + std::unordered_map let_var_to_expr_; +}; + +class WarpSpecializedRoleMarker : public StmtVisitor { +public: + WarpSpecializedRoleMarker(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + void Prepare(const Stmt &stmt) { + ProducerUsedBufferFinder finder; + producer_buffers_ = finder.FindProducerusedBuffer(stmt); + } + + Role GetRole(const StmtNode *stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); } + + void VisitStmt_(const EvaluateNode *op) final { + Role role = Role::kConsumer; + if (auto call = op->value.as()) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + role = Role::kProducer; + has_bulk_copy_ = true; + } + if (call->op.same_as(loop_break())) { + role = Role::kBoth; + } + } + SetRole(op, role); + } + + void VisitStmt_(const BufferStoreNode *op) final { + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + bool is_shared_store = scope.rank == StorageRank::kShared; + if (producer_buffers_.count(op->buffer.get())) { + SetRole(op, Role::kBoth); + return; + } + if (!is_shared_store) { + SetRole(op, Role::kConsumer); + return; + } + + // Check reads from global + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ tvm::ffi::GetRef(op)); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + Role role = Role::kProducer; + if (reads.empty()) + role = Role::kConsumer; + for (auto read : reads) { + if (read->buffer.scope() != "global") { + role = Role::kConsumer; + break; + } + } + if (role == Role::kProducer) + has_simt_copy_ = true; + SetRole(op, role); + } + + void VisitStmt_(const SeqStmtNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetRole(stmt)) { + role = Role::kBoth; + break; + } + } + SetRole(op, role); + } + + void VisitStmt_(const IfThenElseNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetRole(op->else_case.value()); + if (role != role_else) + role = Role::kBoth; + } + SetRole(op, role); + } + + void VisitStmt_(const BlockRealizeNode *op) final { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->block)); + } + + void VisitStmt_(const AllocateNode *op) final { + StmtVisitor::VisitStmt_(op); + Role role = Role::kConsumer; + SetRole(op, role); + } + + template void HandleBodyStmt(const NodeType *op) { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->body)); + } + + void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); } + + bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + + bool HasSimtCopy() { return has_simt_copy_; } + +private: + void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; } + Map buffer_data_to_buffer_; + std::unordered_map map_; + bool has_simt_copy_ = false; + bool has_bulk_copy_ = false; + std::unordered_set 
producer_buffers_; +}; + +static PrimExpr makeGetBarrier(PrimExpr barrier_id) { + return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)}); +} + +static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1, + const PrimExpr &pred = 1) { + Array args = {makeGetBarrier(std::move(barrier_id))}; + if (cta_id != -1) { + args.push_back(cta_id); + args.push_back(pred); + } + return Evaluate( + Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args)); +} + +static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { + auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), + {makeGetBarrier(std::move(barrier_id))}); + return Evaluate(call); +} + +static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { + auto call = Call(DataType::Handle(), mbarrier_wait_parity(), + {makeGetBarrier(std::move(barrier_id)), std::move(parity)}); + return Evaluate(call); +} + +class ProducerTraitsCollector : public StmtExprVisitor { +public: + ProducerTraitsCollector() { Clear(); } + + void Clear() { has_simt_copy = false; } + + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + bool HasSimtCopy() { return has_simt_copy; } + +private: + void VisitStmt_(const IfThenElseNode *op) final { + bool old_in_if_cond = in_if_cond_; + in_if_cond_ = true; + VisitExpr(op->condition); + in_if_cond_ = old_in_if_cond; + + VisitStmt(op->then_case); + if (op->else_case.defined()) { + VisitStmt(op->else_case.value()); + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (!in_if_cond_) { + has_simt_copy = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_simt_copy{}; + bool in_if_cond_ = false; +}; + +// Rewrite the producer Stmt to use the correct barrier index +class MbarrierRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { + MbarrierRewriter rewriter; + rewriter.producer_barrier_idx_ = std::move(barrier_id); + return rewriter(std::move(stmt)); + } + +private: + PrimExpr VisitExpr_(const CallNode *op) final { + auto call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + auto mbar = makeGetBarrier(producer_barrier_idx_); + auto arg0 = call->args[0].as(); + // Check if this is a 1D TMA load + auto is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + call->op.same_as(tma_load()); + if (is_1d_tma_load) { + call.CopyOnWrite()->args.Set(2, mbar); + } else { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + call.CopyOnWrite()->args.Set(1, mbar); + } + } + return call; + } + PrimExpr producer_barrier_idx_; +}; + +class ThreadIdxRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced, + PrimExpr thread_extent, bool do_shuffle = false) { + auto rewriter = + ThreadIdxRewriter(std::move(thread_var), std::move(replaced), + std::move(thread_extent), do_shuffle); + return rewriter(std::move(stmt)); + } + +private: + ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent, + bool do_shuffle) + : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)), + thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {} + + PrimExpr VisitExpr_(const VarNode *var) final { + if (var == thread_var_.get()) { + return replaced_; + } else { + return StmtExprMutator::VisitExpr_(var); + } + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { + auto f_uses_thread_index = [=](const 
tvm::tir::VarNode *parameter) { + return parameter == thread_var_.get(); + }; + maybe_thread_opt_ = false; + if (!op->else_case.defined() && op->condition.as() && + UsesVar(op->condition, f_uses_thread_index) && + !(UsesVar(op->then_case, f_uses_thread_index))) { + auto eq_op = Downcast(op->condition); + if (eq_op->a.as() == thread_var_.get() || + eq_op->b.as() == thread_var_.get()) { + maybe_thread_opt_ = true; + } + auto then_case = StmtExprMutator::VisitStmt(op->then_case); + maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_; + has_tma_op_ = false; + if (maybe_thread_opt_) { + return IfThenElse( + Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}), + StmtExprMutator::VisitStmt(op->then_case), std::nullopt); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl::tma_load()) || + op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_store())) { + has_tma_op_ = true; + } + return StmtExprMutator::VisitExpr_(op); + } + + Var thread_var_; + PrimExpr replaced_; + PrimExpr thread_extent_; + bool maybe_thread_opt_ = false; + bool do_shuffle_; + bool has_tma_op_ = false; +}; + +Block MakeGroupBlock(const Stmt &stmt, + const Map &annotations) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ stmt, + /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, + /*annotations=*/annotations); + return block; +} + +struct OpInfo { + int group_size{}, order{}, stage{}; + std::vector group; +}; +struct PipelineInfo { + std::vector op_infos; + + PipelineInfo() = default; + PipelineInfo(const Array> &group_info, + const Array &order_info, + const Array &stage_info) { + int n = static_cast(group_info.size()); + ICHECK(n == static_cast(order_info.size())); + ICHECK(n == static_cast(stage_info.size())); + // int cur_id = 0; + for (int i = 0; i < n; i++) { + OpInfo op_info; + op_info.group_size = group_info[i].size(); + for (int j = 0; j < op_info.group_size; j++) { + op_info.group.push_back(group_info[i][j].as()->value); + } + op_info.order = order_info[i].as()->value; + op_info.stage = stage_info[i].as()->value; + op_infos.push_back(op_info); + } + } + + PipelineInfo(const PipelineInfo &other) { + for (const auto &op_info : other.op_infos) { + op_infos.push_back(op_info); + } + } + + std::pair FindStmt(int stmt_idx) { + for (size_t i = 0; i < op_infos.size(); i++) { + for (size_t j = 0; j < op_infos[i].group.size(); j++) { + if (op_infos[i].group[j] == stmt_idx) { + return std::make_pair(i, j); + } + } + } + return std::make_pair(-1, -1); + } + + void UpdateOrder(int order) { + for (int i = 0; i < static_cast(op_infos.size()); i++) { + if (op_infos[i].order >= order && op_infos[i].order > 0) { + op_infos[i].order++; + } + } + } + + int SplitOp(int stmt_idx) { + auto pair = FindStmt(stmt_idx); + int op_idx = pair.first; + int inner_idx = pair.second; + ICHECK(op_idx != -1); + ICHECK(inner_idx != -1); + OpInfo half0; + OpInfo half1; + // The order to do sync + int sync_order = op_infos[op_idx].order + 1; + UpdateOrder(sync_order); + + half0.group_size = inner_idx + 1; + half0.order = op_infos[op_idx].order; + half0.stage = op_infos[op_idx].stage; + for (int i = 0; i <= inner_idx; i++) { + half0.group.push_back(op_infos[op_idx].group[i]); + } + half1.group_size = op_infos[op_idx].group_size - inner_idx - 1; + half1.order = op_infos[op_idx].order + 2; + half1.stage = op_infos[op_idx].stage; + for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) 
{ + half1.group.push_back(op_infos[op_idx].group[i]); + } + op_infos.erase(op_infos.begin() + op_idx); + if (half0.group_size > 0) { + op_infos.insert(op_infos.begin() + op_idx, half0); + } + if (half1.group_size > 0) { + UpdateOrder(half1.order); + op_infos.insert(op_infos.begin() + op_idx + 1, half1); + } + return sync_order; + } + + void PrintPipelineInfo() { + std::cout << "Print op_infos:" << '\n'; + for (size_t i = 0; i < op_infos.size(); i++) { + std::cout << i << " " << op_infos[i].group_size << " " + << op_infos[i].order << " " << op_infos[i].stage << '\n'; + } + std::cout << "End of print" << '\n'; + } +}; + +class GroupOpRewriter : public StmtExprMutator { +public: + GroupOpRewriter(const PipelineInfo &pipeline_info) + : pipeline_info_(pipeline_info) {} + +private: + Stmt VisitStmt_(const ForNode *op) final { + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + auto original_node = (op->body).as(); + if (!original_node) { + return tvm::ffi::GetRef(op); + } + Array new_body; + int cur_id = 0; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); i++) { + if (pipeline_info_.op_infos[i].group_size == 0) + continue; + Array block_stmt; + for (int j = 0; + j < static_cast(pipeline_info_.op_infos[i].group_size); j++) { + // ICHECK(group_info_[i][j].as()); + // int index = + // static_cast(group_info_[i][j].as()->value); + ICHECK(original_node->seq[cur_id].as()); + auto block = original_node->seq[cur_id].as(); + // TODO: handle nested seqstmt + block_stmt.push_back(block->body); + cur_id++; + } + new_body.push_back(MakeGroupBlock( + block_stmt.size() == 1 ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + Array order_anno; + Array stage_anno; + for (const auto &op_info : pipeline_info_.op_infos) { + order_anno.push_back(Integer(op_info.order)); + stage_anno.push_back(Integer(op_info.stage)); + } + Map for_annotations = op->annotations; + for_annotations.erase("tl_pipeline_group"); + for_annotations.Set("software_pipeline_order", order_anno); + for_annotations.Set("software_pipeline_stage", stage_anno); + For new_for = + For(op->loop_var, op->min, op->extent, op->kind, + new_body.size() == 1 ? 
new_body[0] : SeqStmt(std::move(new_body)), + op->thread_binding, for_annotations); + return new_for; + } + + PipelineInfo pipeline_info_; +}; + +class WgMMACollector : public StmtExprVisitor { +public: + WgMMACollector() = default; + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) { + auto op_name = std::string(op->args[0].as()->value); + if (has_wgmma_) { + has_wgmma_ = + op_name.find("false") == std::string::npos && !in_if_scope_; + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode *op) final { + in_if_scope_ = true; + StmtExprVisitor::VisitStmt(op->then_case); + if (op->else_case.defined()) { + StmtExprVisitor::VisitStmt(op->else_case.value()); + } + in_if_scope_ = false; + } + + static bool HasWgMMA(const Stmt &stmt) { + auto collector = WgMMACollector(); + collector(stmt); + return collector.has_wgmma_; + } + + bool has_wgmma_{true}; + bool in_if_scope_{false}; +}; + +class WSCodeEmitter : public StmtMutator { +public: + WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv, + Map buffer_data_to_buffer, + const WarpSpecializedRoleMarker &marker, + bool mbarrier_only = false) + : is_emitting_producer_(is_emitting_producer), + buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + marker_(marker), thread_var_(thread_iv->var), + mbarrier_only_(mbarrier_only) {} + + /** + * @brief Whether a SIMT-style bulk copy was detected. + * + * Returns true when a simulated SIMT (thread-parallel) copy pattern was + * observed during analysis/emission, which can affect barrier insertion and + * copy emission. + * + * @return true if a SIMT copy was detected; false otherwise. + */ + bool hasSimtCopy() const { return has_simt_copy_; } + +private: + template < + typename NodeType> /** + * @brief Filter a statement by its producer/consumer + * role for emission. + * + * Returns one of: + * - the original statement (unchanged) when this + * emitter should emit it, + * - the result of visiting the statement (to descend + * into it) when mbarrier-only mode requires full + * traversal for non-producer roles, + * - an empty evaluate (`Evaluate(0)`) when the + * statement should be omitted. + * + * The decision is based on the role of `op` as + * reported by `marker_`, the emitter mode + * (`is_emitting_producer_`), and the `mbarrier_only_` + * flag. + * + * @param op The statement node to filter; its role is + * queried via `marker_`. + * @return Stmt The statement to place into the emitted + * IR (possibly transformed or an empty evaluate). 
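+                       *
+                       * For example (illustrative): when emitting the
+                       * producer side, a consumer-only statement becomes
+                       * Evaluate(0), while a statement marked kBoth is
+                       * visited recursively.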
+ */ + Stmt FilterByRole(const NodeType *op) { + Role role = marker_.GetRole(op); + if (mbarrier_only_) { + if (role != Role::kProducer) + return StmtMutator::VisitStmt_(op); + } + if (role == Role::kBoth) { + return StmtMutator::VisitStmt_(op); + } else if ((role == Role::kProducer) == is_emitting_producer_) { + return tvm::ffi::GetRef(op); + } else { + return Evaluate(0); + } + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + + bool has_producer = false; + for (auto stmt : op->seq) { + if (marker_.GetRole(stmt) == Role::kProducer) { + has_producer = true; + break; + } + } + bool need_producer_sync = + has_producer && marker_.GetRole(op) == Role::kBoth; + if (!need_producer_sync) + return FilterByRole(op); + + auto seq_transformed = + op->seq.Map([&](const Stmt &stmt) { return VisitStmt(stmt); }); + + auto map = ExtractSyncPattern(op->seq); + + /* + std::cout << "Print ExtractSyncPattern" << std::endl; + for (int i = 0; i < static_cast(op->seq.size()); i++) { + std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " + << map.release_after[i] << std::endl; + } + std::cout << "Print sync pattern" << std::endl; + for (auto pattern : map.patterns) { + std::cout << pattern.release_idx << " " << pattern.acquire_idx << + std::endl; + } + std::cout << "End of ExtractSyncPattern" << std::endl; + pipeline_info_.PrintPipelineInfo(); + */ + Array new_body; + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + + if (is_emitting_producer_) { // producer case + ProducerTraitsCollector collector; + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (!mbarrier_only_) { + if (marker_.GetRole(op->seq[i]) == Role::kConsumer) + continue; + if (marker_.GetRole(op->seq[i]) == Role::kBoth) { + block_stmt.push_back(seq_transformed[i]); + new_body.push_back( + MakeGroupBlock(block_stmt.size() == 1 + ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + continue; + } + } + + for (int pattern_idx : map.acquire[i]) { + PrimExpr acquire_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + PrimExpr parity = map.is_loop_dependency(pattern_idx) + ? bitwise_xor(parity_, 1) + : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + ICHECK(!map.release[i].empty()); + for (size_t j = 0; j < map.release[i].size(); j++) { + int pattern_idx = map.release[i][j]; + PrimExpr release_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + auto stmt = + MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); + collector.Collect(stmt); + block_stmt.push_back(stmt); + if (collector.HasSimtCopy()) { + block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id)); + has_simt_copy_ = true; + } + if (map.release_after[i][j]) { + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int s = 0; s < num_stages_; s++) { + released_barrier_.insert(s + num_barriers_ + + num_stages_ * pattern_idx); + } + } + collector.Clear(); + new_body.push_back( + MakeGroupBlock(block_stmt.size() == 1 + ? 
block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + } + } else { // consumer case + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (marker_.GetRole(op->seq[i]) == Role::kProducer) + continue; + for (int pattern_idx : map.acquire[i]) { + PrimExpr acquire_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + PrimExpr parity = map.is_loop_dependency(pattern_idx) + ? bitwise_xor(parity_, 1) + : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + block_stmt.push_back(seq_transformed[i]); + for (size_t j = 0; j < map.release[i].size(); j++) { + if (map.release_after[i][j]) { + int pattern_idx = map.release[i][j]; + PrimExpr release_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int s = 0; s < num_stages_; s++) { + released_barrier_.insert(s + num_barriers_ + + num_stages_ * pattern_idx); + } + } + } + new_body.push_back(MakeGroupBlock( + block_stmt.size() == 1 ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + // Filter out the producer stmts + int cur_id = 0; + PipelineInfo new_pipeline_info; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); + i++) { + auto op_info = pipeline_info_.op_infos[i]; + bool is_producer = false; + for (int j = 0; j < op_info.group_size; j++) { + if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) { + is_producer = true; + } + cur_id++; + } + if (is_producer) { + ICHECK(op_info.group_size == 1); + } else { + new_pipeline_info.op_infos.push_back(op_info); + } + } + pipeline_info_ = new_pipeline_info; + } + + num_barriers_ += map.patterns.size() * num_stages_; + + ICHECK(!new_body.empty()); + return new_body.size() == 1 ? 
new_body[0] : SeqStmt(std::move(new_body)); + } + + Stmt VisitStmt_(const ForNode *op) final { + int num_stages = 1; + auto num_stages_anno = op->annotations.Get("num_stages"); + if (num_stages_anno) { + ICHECK(num_stages_anno->as()); + num_stages = static_cast(num_stages_anno->as()->value); + ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; + } + loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min}); + + Array> group_info_array; + Array order_info_array; + Array stage_info_array; + + auto group_anno = op->annotations.Get("tl_pipeline_group"); + if (group_anno) { + group_info_array = Downcast>>(group_anno.value()); + } + auto order_anno = op->annotations.Get("tl_pipeline_order"); + if (order_anno) { + order_info_array = Downcast>(order_anno.value()); + } + auto stage_anno = op->annotations.Get("tl_pipeline_stage"); + if (stage_anno) { + stage_info_array = Downcast>(stage_anno.value()); + } + + PipelineInfo pipeline_info(group_info_array, order_info_array, + stage_info_array); + if (!pipeline_info.op_infos.empty()) { + ICHECK(pipeline_info_.op_infos.empty()) + << "Nested pipeline not supported."; + } + + PrimExpr parity_before = std::move(parity_); + PrimExpr stage_before = std::move(stage_); + int num_stages_before = num_stages_; + PipelineInfo pipeline_info_before = pipeline_info_; + + num_stages_ = num_stages; + pipeline_info_ = pipeline_info; + PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min; + for (size_t i = 1; i < loop_stack_.size(); ++i) { + linear_index = linear_index * loop_stack_[i].extent + + (loop_stack_[i].loop_var - loop_stack_[i].min); + } + stage_ = FloorMod(linear_index, num_stages); + parity_ = FloorMod( + parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); + auto result = FilterByRole(op); + + Stmt grouped_for_node; + if (result.as() && group_anno && !group_info_array.empty() && + !is_emitting_producer_) { + GroupOpRewriter group_op_rewriter(pipeline_info_); + auto for_node = Downcast(result); + grouped_for_node = group_op_rewriter(for_node); + } + + parity_ = std::move(parity_before); + stage_ = std::move(stage_before); + num_stages_ = num_stages_before; + pipeline_info_ = pipeline_info_before; + + // remove pipeline annotation + auto for_node = result.as(); + if (result.as()) { + auto for_node = Downcast(result); + for_node.CopyOnWrite()->annotations.erase("num_stages"); + if (is_emitting_producer_ || group_info_array.empty()) { + for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order"); + for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage"); + } + if (is_emitting_producer_ || !group_anno || group_info_array.empty()) { + loop_stack_.pop_back(); + return for_node; + } + loop_stack_.pop_back(); + return grouped_for_node; + } + loop_stack_.pop_back(); + return result; + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BlockNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BlockRealizeNode *op) final { return FilterByRole(op); } + + struct SyncPattern { + int release_idx, acquire_idx; + }; + + struct SyncPatternMap { + std::vector> 
acquire; + std::vector> release; + std::vector> release_after; + std::vector patterns; + + void resize(size_t n) { + acquire.resize(n); + release.resize(n); + release_after.resize(n); + } + + bool is_loop_dependency(int pattern_idx) { + return patterns[pattern_idx].release_idx > + patterns[pattern_idx].acquire_idx; + } + }; + + std::vector + CreateBaseSyncPairs(const Array &seq_stmt, + const std::vector &is_producer) { + const int n = seq_stmt.size(); + std::vector> reads, writes; + reads.reserve(n); + writes.reserve(n); + for (int i = 0; i < n; i++) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ seq_stmt[i]); + auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) { + auto var = region->buffer->data; + if (buffer_data_to_buffer_.count(var)) { + read_set.insert(buffer_data_to_buffer_[var].get()); + } else { + read_set.insert(region->buffer.get()); + } + } + for (auto region : access[1]) { + auto var = region->buffer->data; + if (buffer_data_to_buffer_.count(var)) { + write_set.insert(buffer_data_to_buffer_[var].get()); + } else { + write_set.insert(region->buffer.get()); + } + } + reads.push_back(std::move(read_set)); + writes.push_back(std::move(write_set)); + } + + auto intersect_fn = [](const std::set &lhs, + const std::set &rhs) { + for (auto ptr : lhs) + if (rhs.count(ptr)) + return true; + return false; + }; + + std::vector sync_patterns; + // producer_release consumer_acquire, + // inject before the first consumer stmt for each producer + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || + intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + + // consumer_release producer_acquire + // valid when is_loop is true + // inject before the earliest producer stmt for each consumer + bool in_loop = !is_zero(parity_); + if (in_loop) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < i; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || + intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + } + + return sync_patterns; + } + + static std::vector + RemoveUnusedSyncPatterns(const std::vector &sync_patterns, + const std::vector &is_producer) { + /* + Simplify multiple release-acquire pairs into one + ------------------ + Produce(A) + Produce(B) + Consume(A, B) + ------------------ + [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)] + + Or + ------------------ + Produce(A, B) + Consume(A) + Consume(B) + ------------------ + [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)] + */ + int M = sync_patterns.size(); + std::vector removed(M, false); + for (int i = 0; i < M; i++) { + for (int j = 0; j < M; j++) { + if (is_producer[sync_patterns[i].acquire_idx] == + is_producer[sync_patterns[j].acquire_idx] && + sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx && + sync_patterns[i].release_idx < sync_patterns[j].release_idx) + removed[i] = true; + } + } + + std::vector sync_pattern_cleaned; + sync_pattern_cleaned.reserve(M); + for (int i = 0; i < M; i++) + if (!removed[i]) + sync_pattern_cleaned.push_back(sync_patterns[i]); + + return sync_pattern_cleaned; + } + + SyncPatternMap ExtractSyncPattern(const Array &seq_stmt) { + size_t num_stmts = seq_stmt.size(); + std::vector is_producer; + is_producer.reserve(num_stmts); + for (auto stmt : seq_stmt) { + 
is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer); + } + + auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); + auto sync_patterns = + RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); + + // for (auto pattern : sync_patterns) { + // std::cout << pattern.release_idx << " " << pattern.acquire_idx << + // std::endl; + // } + + SyncPatternMap map; + map.resize(num_stmts); + map.patterns = sync_patterns; + + for (size_t i = 0; i < sync_patterns.size(); i++) { + int acquire_idx = sync_patterns[i].acquire_idx; + int release_idx = sync_patterns[i].release_idx; + + map.acquire[acquire_idx].push_back(i); + map.release[release_idx].push_back(i); + map.release_after[release_idx].push_back(true); + } + + std::vector cur_consumer_barrier, cur_producer_barrier; + for (int i = num_stmts - 1; i >= 0; i--) { + if (is_producer[i]) { + if (map.release[i].empty()) { + for (auto pattern_idx : cur_producer_barrier) { + map.release[i].push_back(pattern_idx); + map.release_after[i].push_back(false); + } + } else { + for (auto pattern_idx : map.release[i]) { + cur_producer_barrier.push_back(pattern_idx); + } + } + } else { + if (map.release[i].empty()) { + for (auto pattern_idx : cur_consumer_barrier) { + map.release[i].push_back(pattern_idx); + map.release_after[i].push_back(false); + } + } else { + for (auto pattern_idx : map.release[i]) { + cur_consumer_barrier.push_back(pattern_idx); + } + } + } + } + return map; + } + + const bool is_emitting_producer_; + Map buffer_data_to_buffer_; + std::unordered_set released_barrier_; + const WarpSpecializedRoleMarker &marker_; + + int num_barriers_ = 0; + PrimExpr parity_ = 0; + PrimExpr stage_ = 0; + int num_stages_ = 1; + std::vector loop_stack_; + Var thread_var_; + bool mbarrier_only_ = false; + PipelineInfo pipeline_info_; + friend class WarpSpecializedRewriter; + bool has_simt_copy_ = false; +}; + +class WarpSpecializedRewriter : public StmtExprMutator { +public: + WarpSpecializedRewriter(bool disable_warp_specialized, + bool disable_shuffle_elect) + : disable_warp_specialized_(disable_warp_specialized), + disable_shuffle_elect_(disable_shuffle_elect) {} + static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized, + bool disable_shuffle_elect) { + // Check if function only uses threadIdx.x before proceeding + if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { + LOG(WARNING) << "WarpSpecialize will be disabled because the program " + "uses thread tags other than threadIdx.x." 
+ << "If you want to use warp specialization, please refactor " + "your program to use threadIdx.x only"; + // Return original function unchanged if other thread tags are found + return f; + } + + auto T = WarpSpecializedRewriter(disable_warp_specialized, + disable_shuffle_elect); + T.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : T.buffer_lca_) + T.buffer_data_to_buffer_.Set(buffer->data, buffer); + f.CopyOnWrite()->body = T(f->body); + return f; + } + +private: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_iv_ = Downcast(op->node); + need_update_thread_extent_ = false; + AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (need_update_thread_extent_) { + thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; + attr_stmt.CopyOnWrite()->node = thread_iv_; + attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); + } + thread_iv_ = {}; + return attr_stmt; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + // If users define a thread binding, we will replace the thread binding with + // threadIdx.x We require the thread binding is threadIdx.x, and the extent is + // the same as the thread extent + Stmt VisitStmt_(const ForNode *op) final { + ICHECK(thread_iv_.defined()); + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (for_node->kind == ForKind::kThreadBinding) { + ICHECK(for_node->thread_binding.defined()); + String thread_tag = for_node->thread_binding.value()->thread_tag; + ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x"; + Var thread_iv = Downcast(for_node->loop_var); + Stmt new_body = + ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_, 0); + return new_body; + } + return for_node; + } + + Stmt VisitStmt_(const BlockRealizeNode *op) final { + BlockRealize block_realize = + Downcast(StmtExprMutator::VisitStmt_(op)); + if (!thread_iv_.defined()) { + return block_realize; + } + + Block block = block_realize->block; + WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); + marker.Prepare(block); + marker(block); + if (!marker.HasProducer()) { + // Cannot detect any producer here, directly return. 
+ return block_realize; + } + + if (disable_warp_specialized_) { + WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_, + marker, true); + auto code = mbarrier_emitter(block->body); + int num_barriers = mbarrier_emitter.num_barriers_; + Array barrier_num_threads; + barrier_num_threads.reserve(num_barriers); + PrimExpr arrive_thread_count = thread_iv_->dom->extent; + for (int i = 0; i < num_barriers; i++) { + barrier_num_threads.push_back(arrive_thread_count); + } + Stmt init_barrier = Evaluate(Call( + DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); + block.CopyOnWrite()->body = SeqStmt({init_barrier, code}); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); + WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker, + false); + Stmt producer_code = producer(block->body); + Stmt consumer_code = consumer(block->body); + PrimExpr consumer_thread_extent = thread_iv_->dom->extent; + PrimExpr producer_thread_extent = thread_iv_->dom->extent; + // Need one warp-group for bulk-copy only case + if (!marker.HasSimtCopy()) + producer_thread_extent = 128; + + updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; + + producer_code = ThreadIdxRewriter::Rewrite( + producer_code, thread_iv_->var, + thread_iv_->var - consumer_thread_extent, producer_thread_extent, + !disable_shuffle_elect_); + consumer_code = ThreadIdxRewriter::Rewrite( + consumer_code, thread_iv_->var, thread_iv_->var, consumer_thread_extent, + !disable_shuffle_elect_); + need_update_thread_extent_ = true; + + ICHECK(producer.num_barriers_ == consumer.num_barriers_) + << producer.num_barriers_ << " " << consumer.num_barriers_; + int num_barriers = consumer.num_barriers_; + Array barrier_num_threads; + barrier_num_threads.reserve(num_barriers); + for (int i = 0; i < num_barriers; i++) { + PrimExpr arrive_thread_count = + producer.released_barrier_.count(i) + ? (producer.hasSimtCopy() ? producer_thread_extent : 1) + : consumer_thread_extent; + barrier_num_threads.push_back(arrive_thread_count); + } + + Stmt init_barrier = Evaluate(Call( + DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); + Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent), + producer_code, consumer_code); + // Add an attr here to handle the partial thread count in ThreadSync pass. 
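+ // The attribute value is the pair {producer_thread_extent, consumer_thread_extent},
+ // so downstream passes can tell how the thread block is partitioned between the two warp groups.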
+ Array ws_partition = {Downcast(producer_thread_extent), + Downcast(consumer_thread_extent)}; + body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body); + + block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + + WarpSpecializedRewriter() = default; + + Map buffer_data_to_buffer_; + Map> buffer_lca_; + Map buffer_remap_; + IterVar thread_iv_; + Optional updated_thread_extent_; + bool need_update_thread_extent_ = false; + bool disable_warp_specialized_ = false; + bool disable_shuffle_elect_ = false; +}; + +using namespace tir::transform; + +tvm::transform::Pass WarpSpecialized() { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_warp_specialized = + ctx->GetConfig(kDisableWarpSpecialized, Bool(false)).value(); + bool disable_shuffle_elect = + ctx->GetConfig(kDisableShuffleElect, Bool(false)).value(); + bool warp_specialized = WarpSpecializedDetector::Detect(f->body); + + if (!warp_specialized) { + return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, + disable_shuffle_elect); + } else { + auto node = ffi::String("default"); + f.CopyOnWrite()->body = + AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); + return f; + } + }; + return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/warp_specialized_rewriter.h b/tilelang/original/src/transform/warp_specialized_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..01a2474a8355b923dc25ec1ab84ca6c2b2937d98 --- /dev/null +++ b/tilelang/original/src/transform/warp_specialized_rewriter.h @@ -0,0 +1,99 @@ +/*! 
+ * \file warp_specialized_rewriter.h + * \brief tools for warp-specialized-related analysis and transformation + */ + +#pragma once + +#include "arith/ir_visitor_with_analyzer.h" +#include "tir/analysis/var_use_def_analysis.h" +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "./common/collector.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using namespace runtime; +using arith::IRVisitorWithAnalyzer; + +class WarpSpecializedDetector : public IRVisitorWithAnalyzer { +public: + // return true means this aws will be disabled + static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { + WarpSpecializedDetector detector; + detector.VisitStmt(stmt); + if (detector.has_warp_specialization_) { + LOG(WARNING) << "Auto warp specialization will be disabled because warp " + "specialization is manually enabled"; + return true; + } + if (detector.has_tma_op_ && detector.has_mbarrier_op_) { + LOG(WARNING) << "Auto warp specialization will be disabled because TMA " + "and mbarrier are both present"; + return true; + } + return false; + } + + WarpSpecializedDetector() { + has_tma_op_ = false; + has_mbarrier_op_ = false; + has_warp_specialization_ = false; + } + +private: + void VisitStmt_(const EvaluateNode *op) final { + if (const CallNode *call = op->value.as()) { + if (call->op.same_as(create_list_of_mbarrier()) || + call->op.same_as(mbarrier_wait_parity()) || + call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_cp_async_barrier())) { + has_mbarrier_op_ = true; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(set_max_nreg())) { + has_tma_op_ = true; + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "warp_specialize" && + op->value.as()->value == 1) { + has_warp_specialization_ = true; + } + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + bool has_tma_op_{false}; + IterVar thread_var_; + bool has_mbarrier_op_{false}; + bool has_warp_specialization_{false}; +}; + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/src/transform/wgmma_sync_rewriter.cc b/tilelang/original/src/transform/wgmma_sync_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..538b491107acdb73f2b3668c5d20ff866de20374 --- /dev/null +++ b/tilelang/original/src/transform/wgmma_sync_rewriter.cc @@ -0,0 +1,275 @@ +/*! 
+ * \file warp_specialized_pipeline.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +bool isGemm(const Stmt &stmt) { + bool is_gemm = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.call_extern"))) { + if (call->args[0].as()) { + std::string name = Downcast(call->args[0])->value; + if (name.find("gemm") != std::string::npos) { + is_gemm = true; + } + } + } + } + return is_gemm; +} + +bool isGemmSync(const Stmt &stmt) { + bool is_gemm_sync = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.call_extern"))) { + if (call->args[0].as()) { + std::string name = Downcast(call->args[0])->value; + if (name.find("warpgroup_wait") != std::string::npos) { + is_gemm_sync = true; + } + } + } + } + return is_gemm_sync; +} + +bool isArriveBarrier(const Stmt &stmt) { + bool is_arrive_barrier = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))) { + is_arrive_barrier = true; + } + } + return is_arrive_barrier; +} + +class WgmmaSyncRewriter : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc f) { + auto T = WgmmaSyncRewriter(); + T.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : T.buffer_lca_) + T.buffer_data_to_buffer_.Set(buffer->data, buffer); + f.CopyOnWrite()->body = T(f->body); + return f; + } + +private: + void CollectWgmmaInfo(const SeqStmtNode *op) { + for (int i = 0; i < static_cast(op->seq.size()); i++) { + auto stmt = op->seq[i]; + if (isGemm(stmt)) { + gemm_stmts_.push_back(stmt); + gemm_stmt_ids_.push_back(i); + bool found_release = false; + for (int j = i + 1; j < static_cast(op->seq.size()); j++) { + auto release_stmt = op->seq[j]; + if (isArriveBarrier(release_stmt)) { + found_release = true; + gemm_release_stmts_.push_back(release_stmt); + break; + } + } + if (!found_release) { + gemm_release_stmts_.push_back(Evaluate(0)); + } + // ICHECK(op->seq.size() > i + 1); + // auto release_stmt = op->seq[i + 1]; + // auto next_call = + // Downcast(release_stmt)->value.as(); + // ICHECK(next_call); + // ICHECK(next_call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ op->seq[i]); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) + read_set.insert(region->buffer.get()); + for (auto region : access[1]) + write_set.insert(region->buffer.get()); + gemm_read_buffers_.push_back(read_set); + gemm_write_buffers_.push_back(write_set); + } + } + } + + Stmt VisitStmt_(const ForNode *op) final { + auto order_anno = op->annotations.Get("tl_pipeline_order"); + if (!order_anno) { + return StmtExprMutator::VisitStmt_(op); + } + + CollectWgmmaInfo(op->body.as()); + auto stmt_node = (op->body).as(); + ICHECK(stmt_node); + + auto intersect_fn = [](const std::set &lhs, + const std::set &rhs) { + for (auto ptr : lhs) + if (rhs.count(ptr)) + return true; + return false; + }; + + for (int r = 0; r < static_cast(gemm_stmts_.size()); r++) { + bool found = false; + auto last_stmt = Stmt(); + for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { + if (stmt_node->seq[i].same_as(gemm_stmts_[r])) { + found = true; + 
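+ // From here on, track the last statement whose reads/writes do not conflict
+ // with this gemm's buffers; the warpgroup_wait call is inserted right after it.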
last_stmt = stmt_node->seq[i]; + continue; + } + if (!found) + continue; + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ stmt_node->seq[i]); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) + read_set.insert(region->buffer.get()); + for (auto region : access[1]) + write_set.insert(region->buffer.get()); + if (intersect_fn(read_set, gemm_write_buffers_[r]) || + intersect_fn(write_set, gemm_read_buffers_[r]) || + intersect_fn(write_set, gemm_write_buffers_[r])) { + break; + } + last_stmt = stmt_node->seq[i]; + } + last_stmts_.push_back(last_stmt); + } + + auto new_seq = Array(); + for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { + bool remove_ = false; + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(gemm_release_stmts_[j])) { + remove_ = true; + continue; + } + } + if (remove_) + continue; + auto stmt = stmt_node->seq[i]; + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(gemm_stmts_[j])) { + auto call = Downcast(stmt)->value.as(); + ICHECK(call); + ICHECK(call->op.same_as(Op::Get("tir.call_extern"))); + ICHECK(call->args[0].as()); + std::string name = Downcast(call->args[0])->value; + std::string new_name = name.substr(0, name.size() - 1) + ", -1>"; + auto new_args = Array(); + new_args.push_back(StringImm(new_name)); + for (int k = 1; k < static_cast(call->args.size()); k++) { + new_args.push_back(call->args[k]); + } + stmt = Evaluate( + Call(DataType::Handle(), builtin::call_extern(), new_args)); + break; + } + } + + new_seq.push_back(stmt); + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(last_stmts_[j])) { + Array new_args; + new_args.push_back(StringImm("cute::warpgroup_wait<0>")); + new_args.push_back(Integer(j)); + auto new_call = + Call(DataType::Handle(), builtin::call_extern(), new_args); + new_seq.push_back(Evaluate(new_call)); + if (std::count(gemm_release_stmts_.begin(), gemm_release_stmts_.end(), + gemm_release_stmts_[j]) == 1) { + new_seq.push_back(gemm_release_stmts_[j]); + } else { + gemm_release_stmts_[j] = Evaluate(0); + } + } + } + } + + int gemm_count = 0; + int max_sync_index = 0; + for (int i = 0; i < static_cast(new_seq.size()); i++) { + if (isGemm(new_seq[i])) { + gemm_count++; + } else if (isGemmSync(new_seq[i])) { + auto call = Downcast(new_seq[i])->value.as(); + auto sync_index = + static_cast(Downcast(call->args[1])->value); + auto wait_count = gemm_count - sync_index - 1; + if (sync_index > max_sync_index) + max_sync_index = sync_index; + if (sync_index < max_sync_index) { + // new_seq.erase(new_seq.begin() + i); + new_seq.Set(i, Evaluate(0)); + } else { + Array new_args; + std::string call_str = + "cute::warpgroup_wait<" + std::to_string(wait_count) + ">"; + new_args.push_back(StringImm(call_str)); + new_seq.Set(i, Evaluate(Call(DataType::Handle(), + builtin::call_extern(), new_args))); + } + } + } + auto new_for = + For(op->loop_var, op->min, op->extent, op->kind, + new_seq.size() == 1 ? 
new_seq[0] : SeqStmt(std::move(new_seq)), + op->thread_binding, op->annotations); + return new_for; + } + + WgmmaSyncRewriter() = default; + + Map> buffer_lca_; + Map buffer_data_to_buffer_; + std::vector> gemm_read_buffers_; + std::vector> gemm_write_buffers_; + std::vector gemm_stmts_; + std::vector gemm_release_stmts_; + std::vector last_stmts_; + + std::vector gemm_stmt_ids_; + friend class WgmmaReleaseCollector; +}; + +using namespace tir::transform; + +tvm::transform::Pass RewriteWgmmaSync() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return WgmmaSyncRewriter::Substitute(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/original/testing/.gitkeep b/tilelang/original/testing/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tilelang/original/testing/__pycache__/conftest.cpython-310-pytest-9.0.2.pyc b/tilelang/original/testing/__pycache__/conftest.cpython-310-pytest-9.0.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07d9e0980843987eb6decf7526d3110916126599 Binary files /dev/null and b/tilelang/original/testing/__pycache__/conftest.cpython-310-pytest-9.0.2.pyc differ diff --git a/tilelang/original/testing/conftest.py b/tilelang/original/testing/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..4010e0d83ae84c641151d6dd56dbf40ee42e301f --- /dev/null +++ b/tilelang/original/testing/conftest.py @@ -0,0 +1,41 @@ +import os +import random +import pytest + +os.environ["PYTHONHASHSEED"] = "0" + +random.seed(0) + +try: + import torch +except ImportError: + pass +else: + torch.manual_seed(0) + +try: + import numpy as np +except ImportError: + pass +else: + np.random.seed(0) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """Ensure that at least one test is collected. Error out if all tests are skipped.""" + known_types = { + "failed", + "passed", + "skipped", + "deselected", + "xfailed", + "xpassed", + "warnings", + "error", + } + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: + terminalreporter.write_sep( + "!", + (f"Error: No tests were collected. 
{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + ) + pytest.exit("No tests were collected.", returncode=5) diff --git a/tilelang/original/testing/cpp/.gitkeep b/tilelang/original/testing/cpp/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..b26354830a3069b1d1cc6ed6dad967c688489ace --- /dev/null +++ b/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -0,0 +1,245 @@ +import pytest +import torch +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.intrinsics.mfma_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + +tilelang.testing.set_random_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + a_transposed=False, + b_transposed=True, + k_pack=1, +): + micro_size_x = micro_size_y = micro_size_k = 16 + + if in_dtype in {T.float8_e4m3fnuz, T.int8}: + micro_size_k = 32 + + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + + chunk = 32 * k_pack + + shared_scope = "shared" + cache_write_shared = False + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size + local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mfma_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + k_pack=k_pack, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # 
Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=0): + # Load A into shared memory + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Load B into shared memory + if b_transposed: + T.copy(B[bx * block_N, ko * block_K], B_shared) + else: + T.copy(B[ko * block_K, bx * block_N], B_shared) + + for ki in T.serial(0, (block_K // (k_pack * micro_size_k))): + # Load A into fragment + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mfma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local) + + # Perform STMatrix + if cache_write_shared: + mfma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + mfma_emitter.stmatrix( + C_local, + C, + pid_m=by, + pid_n=bx, + ) + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.float32, a_transposed=False, b_transposed=True, k_pack=1): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) + print(matmul) + kernel = tilelang.compile(matmul) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + if in_dtype == T.int8: + A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) + elif in_dtype == T.float8_e4m3fnuz: + A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) + B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) + else: + A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + kernel(A, B, C) + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler() + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + if a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + elif a_transposed and not b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + elif not a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + else: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack", + [ + (128, 128, 128, T.float16, T.float16, T.float32, False, True, 1), + (128, 256, 256, T.float16, T.float32, T.float32, False, True, 1), + (128, 256, 256, T.float16, T.float32, T.float32, False, True, 2), + (128, 128, 128, T.int8, T.int32, T.int32, 
False, True, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, True, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, True, 2), + (128, 256, 256, T.int8, T.int32, T.int32, False, False, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, False, 2), + (128, 128, 128, T.float8_e4m3fnuz, T.float16, T.float32, False, True, 1), + ], +) +@tilelang.testing.requires_rocm +def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack): + assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + k_pack=k_pack, + ) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False, k_pack=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..dc95eb7010b90464676ce0320bd5fdcb17adbf4d --- /dev/null +++ b/tilelang/original/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -0,0 +1,304 @@ +import pytest +import torch +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.transform import simplify_prim_func + +tilelang.testing.set_random_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + a_transposed=False, + b_transposed=True, + k_pack=1, + b_preshuffle=False, + b_g2l_load=False, +): + micro_size_x = micro_size_y = micro_size_k = 16 + + if in_dtype in {T.float8_e4m3fnuz, T.int8}: + micro_size_k = 32 + + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + + # for preshuffle_b, warp_layout = {1, 4} + if b_preshuffle: + block_row_warps = 1 + block_col_warps = 4 + warp_row_tiles = 64 + warp_col_tiles = 16 + + chunk = 256 * k_pack + + pack_size_k = micro_size_k * k_pack + + shared_scope = "shared" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (K, M) if a_transposed else (M, K) + if b_preshuffle: + B_shape = ( + (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y) + ) + else: + B_shape = (N, K) if b_transposed else (K, N) + + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + if b_preshuffle: + B_shared_shape = ( + (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y) + ) + else: + B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size + local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * 
micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mfma_emitter = MatrixCorePreshuffleIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + k_pack=k_pack, + b_preshuffle=b_preshuffle, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) + + num_ko = K // block_K + num_ki = block_K // (k_pack * micro_size_k) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined(num_ko, num_stages=0): + # Load A into shared memory + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Load B into shared memory + if b_g2l_load is False: + if b_transposed: + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k): + B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk] + else: + for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y): + B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj] + + for ki in T.serial(0, num_ki): + # Load A S2L + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + if b_g2l_load: + # Load B G2L + mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx) + else: + # Load B S2L + mfma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local) + + # Perform STMatrix + mfma_emitter.stmatrix( + C_local, + C, + pid_m=by, + pid_n=bx, + ) + + return main + + +def shuffle_weight( + x: torch.Tensor, + layout=(16, 32), + k_pack=1, + is_transpose=False, +) -> torch.Tensor: + IN, IK = layout + BK = IK * k_pack + BN = IN + + N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2]) + assert N % BN == 0 + assert K % BK == 0 + + x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN) + x = x.permute(0, 2, 1, 3) + return x.contiguous() + + +def assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype=T.float32, + a_transposed=False, + b_transposed=True, + k_pack=1, + b_preshuffle=False, + b_g2l_load=False, +): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load) + print(matmul) + kernel = tilelang.compile(matmul) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + A_shape = (K, M) if a_transposed else (M, K) + B_shape 
= (N, K) if b_transposed else (K, N) + if in_dtype == T.int8: + A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) + elif in_dtype == T.float8_e4m3fnuz: + A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) + B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) + else: + A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + B_preshuffle = B + if b_preshuffle: + B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed) + kernel(A, B_preshuffle, C) + else: + kernel(A, B, C) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler() + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + if a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + elif a_transposed and not b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + elif not a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + else: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + + print(C) + print(ref_c) + + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load", + [ + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 1, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 1, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 2, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 2, True, False), + ], +) +@tilelang.testing.requires_rocm +def test_assert_tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + a_transposed, + b_transposed, + k_pack, + b_preshuffle, + b_g2l_load, +): + assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + k_pack=k_pack, + b_preshuffle=b_preshuffle, + b_g2l_load=b_g2l_load, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/amd/test_tilelang_test_amd.py b/tilelang/original/testing/python/amd/test_tilelang_test_amd.py new file mode 100644 index 0000000000000000000000000000000000000000..4035c299c3f56fd38fb8486190c5bfeeb8215cd4 --- /dev/null +++ b/tilelang/original/testing/python/amd/test_tilelang_test_amd.py @@ -0,0 +1,264 @@ +import pytest +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def matmul( + M, + N, + K, + block_M, + block_N, + 
block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + k_pack=1, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + vec_size = 4 * k_pack + + @T.prim_func + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared, coalesced_width=vec_size) + else: + T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size) + else: + T.copy(B[k * block_K, bx * block_N], B_shared, coalesced_width=vec_size) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B, k_pack=k_pack) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=0, + num_threads=128, + k_pack=1, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + k_pack=k_pack, + ) + kernel = tl.compile(program, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + return (A @ B).to(torch.__getattribute__(out_dtype)) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "trans_A, trans_B, k_pack", + [ + (False, False, 1), + (False, True, 1), + (True, True, 1), + (True, False, 1), + (False, True, 2), + ], +) +@tilelang.testing.requires_rocm +def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack): + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.float16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack) + + +@pytest.mark.parametrize( + "trans_A, trans_B, k_pack", + [ + (False, False, 1), + (False, True, 1), + (True, True, 1), + (True, False, 1), + (False, True, 2), + ], +) +@tilelang.testing.requires_rocm +def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack): + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack) + + +@pytest.mark.parametrize( + "trans_A, trans_B, k_pack", + [ + (False, False, 1), + (False, True, 1), + (True, True, 1), + (True, False, 1), + (False, True, 2), + ], +) +@tilelang.testing.requires_rocm +def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack): + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32, k_pack=k_pack) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + k_pack=1, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + vec_size = 4 * k_pack + + @T.prim_func + def main( + A: 
T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + A_local = T.alloc_fragment(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared, coalesced_width=vec_size) + T.copy(A_shared, A_local) + else: + T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size) + T.copy(A_shared, A_local) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size) + else: + T.copy(B[k * block_K, bx * block_N], B_shared, coalesced_width=vec_size) + T.gemm(A_local, B_shared, C_local, trans_A, trans_B, k_pack=k_pack) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=0, + num_threads=128, + k_pack=1, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + k_pack=k_pack, + ) + kernel = tl.compile(program, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + return (A @ B).to(torch.__getattribute__(out_dtype)) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +# @tilelang.testing.requires_rocm +# def test_gemm_rs_f16f32f32_nt(): +# run_gemm_rs(1024, 1024, 1024, False, False, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.float16, T.float32, T.float32, 128, 128, 32) + +# @tilelang.testing.requires_rocm +# def test_gemm_rs_bf16f32f32_nt(): +# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.float32, T.float32, 128, 128, 32) + +# @tilelang.testing.requires_rocm +# def test_gemm_rs_bf16bf16f32_nt(): +# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/tilelang/original/testing/python/analysis/test_tilelang_fragment_loop_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..99458f1c85521753fcd71ac678abd9b3906845ca --- /dev/null +++ b/tilelang/original/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -0,0 +1,151 @@ +import tilelang +import tilelang.testing 
+import tilelang.language as T +import pytest + + +@tilelang.jit +def simple_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): + data_frag[i] = 0 + + return main + + +@tilelang.jit +def nested_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A // 64): + for j in T.Parallel(64): + data_frag[i * 64 + j] = 0 + + return main + + +@tilelang.jit +def invalid_loop_with_complex_dataflow(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): + data_frag[64 // 2 + i % 64] = 0 + + return main + + +@tilelang.jit +def valid_loop_not_use_loop_var(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): # noqa: B007 + for j in T.Parallel(64): + data_frag[j] = 0 # This is valid because we don't use i + + return main + + +@tilelang.jit +def valid_loop_not_frag(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_shared = T.alloc_shared([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_shared[i] = data[tid, i] + + for i in T.Parallel(A): + data_shared[i] = 0 # Valid because this is shared memory + + return main + + +@tilelang.jit +def valid_loop_serial(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_shared = T.alloc_shared([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_shared[i] = data[tid, i] + + for i in T.serial(A): + data_shared[i] = 0 # Valid because this is serial + + return main + + +def test_invalid_loop(): + with pytest.raises(ValueError): + simple_invalid_loop() + with pytest.raises(ValueError): + nested_invalid_loop() + with pytest.raises(ValueError): + invalid_loop_with_complex_dataflow() + + +def test_valid_loop(): + valid_loop_not_use_loop_var() + valid_loop_not_frag() + valid_loop_serial() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/tilelang/original/testing/python/analysis/test_tilelang_nested_loop_checker.py b/tilelang/original/testing/python/analysis/test_tilelang_nested_loop_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..664fda5b81d6622017a6787d47f083a64b3c88f7 --- /dev/null +++ b/tilelang/original/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -0,0 +1,719 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import pytest + +tilelang.testing.set_random_seed() + + +def _require_cuda_tensor(shape, dtype=torch.float32): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randn(*shape, device="cuda", dtype=dtype) + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +""" +Nested Parallel cases: + +T.Parallel + T.Parallel + +Rule: + - continuous parallels is allowed and will be merged into one T.Parallel. + - Non-continuous (e.g. with some statements in the outer-loop) are forbidden. +""" + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_parallels(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block1 // block2): + for j in T.Parallel(block1): + for k in T.Parallel(block2): + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_noncontinuous_parallels(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + B[i] = 0 + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +def test_nested_parallels(): + kernel1 = nested_continuous_parallels(length=256, block=16) + kernel2 = nested_triple_continuous_parallels(length=256, block1=8, block2=2) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + result2 = kernel2(data) + torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5) + + # This is invalid + with pytest.raises(ValueError): + nested_noncontinuous_parallels(length=256, block=16) + + +""" +Nested Pipeline cases: + +T.Pipeline + T.Pipeline + +is OK. 
+""" + + +def matmul_nested_pipelines( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + for _ in T.Pipelined(extra_pipeline_repeats): + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_nested_pipelines( + order, + stage, + extra_pipeline_repeats, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + trans_A = False + trans_B = False + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 + num_threads = 128 + program = matmul_nested_pipelines( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + extra_pipeline_repeats, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_nested_pipelines(): + run_gemm_nested_pipelines(order=[0, 1, 2], stage=[0, 0, 1], extra_pipeline_repeats=3) + + +""" +Nested serial cases: + +T.serial + T.serial + +is OK. 
+""" + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_serials(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_noncontinuous_serials(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + B[i] = 0 + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +def test_nested_serials(): + kernel1 = nested_continuous_serials(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5) + + # This is valid + nested_noncontinuous_serials(length=256, block=16) + + +""" +Mixed serial and Parallel loops: + +(S-P) +T.serial + T.Parallel + +(P-S) +T.Parallel + T.serial + +Rule: + - No Parallel - * - Parallel +""" + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_sp(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_ps(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_psp(length=256, block1=8, block2=2, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block1 // block2): + for j in T.serial(block1): + for k in T.Parallel(block2): + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_sps(length=256, block1=8, block2=2, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block1 // block2): + for j in T.Parallel(block1): + for k in T.serial(block2): + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + + return main + + +def test_mixed_sp(): + kernel1 = nested_continuous_sp(length=256, block=16) + kernel2 = nested_continuous_ps(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + result2 = kernel2(data) + torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5) + + # This should be invalid (Undefined behaviour) + with pytest.raises(ValueError): + nested_continuous_psp(length=256, block1=16, block2=8) + + kernel3 = nested_continuous_sps(length=256, block1=8, block2=2) + result3 = kernel3(data) + torch.testing.assert_close(result3, data + 1.0, 
atol=1e-5, rtol=1e-5) + + +""" +Mixed Pipelined and Parallel loops: + +(Pi-Pa) +T.Pipelined + T.Parallel + +(Pa-Pi) +T.Parallel + T.Pipelined + +Rule: + - Pi-Pa is ok where Pa-Pi is not allowed. + - For more nested cases, refer to the rule of T.Parallel. +""" + + +def matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_nested_papipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for _ in T.Parallel(1): + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_mixed_pp( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 + num_threads = 128 + + program = matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if in_dtype == T.float32: + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = 
torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + program1 = matmul_nested_papipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + with pytest.raises(ValueError): + tilelang.compile( + program1, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + +def test_mixed_pp(): + run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1]) + + +""" +TiledOp in a T.Parallel is also not permitted. +""" + + +def matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + for _ in T.Parallel(1): + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_tiled_op_with_parallel( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 + num_threads = 128 + + program = matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if in_dtype == T.float32: + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + program1 = matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + with pytest.raises(ValueError): + tilelang.compile( + program1, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + +@tilelang.jit(out_idx=[1]) +def tir_op_with_parallel(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: 
T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = T.max(A[i * block + j], 0.0) + + return main + + +@tilelang.jit(out_idx=[1]) +def customize_op_with_parallel(length=256, block=16, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + T.atomic_add(B[i * block + j], 1.0) + + return main + + +def test_tiled_op_with_parallel(): + run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1]) + + kernel1 = tir_op_with_parallel(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5) + kernel2 = customize_op_with_parallel(length=256, block=16) + result2 = kernel2(data) + torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/arith/test_arith_hard.py b/tilelang/original/testing/python/arith/test_arith_hard.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc859ba63d3f6803e258ce1fdad0711a09568df --- /dev/null +++ b/tilelang/original/testing/python/arith/test_arith_hard.py @@ -0,0 +1,105 @@ +import tilelang.testing +import tilelang.language as T +from tvm.arith import Analyzer +from tvm.ir.expr import Range +from tvm.tir.expr import Not, Or + + +def implies(x, y): + return Or(Not(x), y) + + +def test_hard_prove(): + a = T.Var("a", T.int32) + b = T.Var("b", T.int32) + c = T.Var("c", T.int32) + d = T.Var("d", T.int32) + + def check_expr(expr): + analyzer = Analyzer() + result = analyzer.can_prove(expr, 1) + if not result: + smtlib2 = analyzer.get_smtlib2(expr) + raise AssertionError(f"Failed to prove: {expr}\nSMT-LIB2:\n{smtlib2}") + # assert result, f"Failed to prove: {expr}" + + @T.macro + def complex_expr_1(): + return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b) + + check_expr(complex_expr_1()) + + @T.macro + def complex_expr_2(): + return implies(a < b and b < c and a * d < b * d, b * d < c * d) + + check_expr(complex_expr_2()) + + @T.macro + def complex_expr_3(): + return implies(a >= 0 and a < 128, a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64) + + check_expr(complex_expr_3()) + + @T.macro + def complex_expr_4(): + return implies( + a >= 0 and a < 128, + (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 + 16 - (a // 64 + a % 8 // 4) // 2 * 64) // 512 + == (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 - (a // 64 + a % 8 // 4) // 2 * 64) // 512, + ) + + check_expr(complex_expr_4()) + + +def test_smtlib2(): + import z3 + + a = T.Var("a", T.int32) + b = T.Var("b", T.int32) + c = T.Var("c", T.int32) + + @T.macro + def complex_expr_1(): + return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b) + + e = complex_expr_1() + analyzer = Analyzer() + analyzer.set_z3_timeout_ms(1000) + smtlib2 = analyzer.get_smtlib2(e) + + solver = z3.Solver() + solver.from_string(smtlib2) + assert solver.check() == z3.unsat, f"Expected unsat, got {solver.check()}" + + +def test_bind(): + a = T.Var("a", T.int32) + b = T.Var("b", T.int32) + c = T.Var("c", T.int32) + + analyzer = Analyzer() + analyzer.bind(a, Range(1, 100000)) + 
analyzer.bind(b, Range(1, 100000)) + analyzer.bind(c, Range(1, 100000)) + + expr = ((b - a) // c) * c + a <= b + smtlib2 = analyzer.get_smtlib2(expr) + try: + result = analyzer.can_prove(expr, 1) + assert result, f"Failed to prove with bindings: {expr}" + except Exception as e: + print(smtlib2) + raise e + + +def test_divmod(): + analyzer = Analyzer() + a = T.Var("a", T.int32) + + assert not analyzer.can_prove(a % 2 % -2 - a % 2 == 0) + assert analyzer.can_prove(a % -2 % 2 - a % 2 == 0) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/arith/test_arith_intset.py b/tilelang/original/testing/python/arith/test_arith_intset.py new file mode 100644 index 0000000000000000000000000000000000000000..e3fc7889ff8cade9633e58da2ec79647b4d679ba --- /dev/null +++ b/tilelang/original/testing/python/arith/test_arith_intset.py @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tilelang import tvm +import tvm.testing +from tvm import te +from tvm import tir +from tvm.arith.analyzer import Analyzer + + +class IntSetChecker: + def __init__(self): + self.analyzer = tvm.arith.Analyzer() + + def verify(self, data, dmap, expected): + res = self.analyzer.int_set(data, dmap) + + def err_msg(): + return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) + + assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg() + assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg() + + +def test_basic(): + s = tvm.arith.IntervalSet(2, 3) + assert s.min_value.value == 2 + assert s.max_value.value == 3 + + s = tvm.arith.IntSet.single_point(2) + assert s.min_value.value == 2 + assert s.max_value.value == 2 + + +def test_vector(): + base = 10 + stride = 3 + lanes = 2 + s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes)) + assert s.min_value.value == base + assert s.max_value.value == base + stride * (lanes - 1) + + +def test_scalable_vector(): + base = 5 + s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4)) + + assert s.min_value.value == base + assert s.max_value.same_as(tvm.arith.int_set.pos_inf()) + + +def test_add_sub(): + ck = IntSetChecker() + x, y = te.var("x"), te.var("y") + ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y)) + ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21)) + ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9)) + + +def test_mul_div(): + ck = IntSetChecker() + x, y = te.var("x"), te.var("y") + + tdiv = tvm.tir.truncdiv + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) + ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, 
(2, 20)) + ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2)) + + ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y))) + ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5)) + + fld = tvm.te.floordiv + ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y))) + ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5)) + + +def test_mod(): + ck = IntSetChecker() + x, y = te.var("x"), te.var("y") + tmod = tvm.tir.truncmod + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) + ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9)) + + flm = tvm.te.floormod + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9)) + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5)) + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 15)}, (0, 9)) + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9)) + ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9)) + + fld = tvm.te.floordiv + z = te.var("z") + ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3)) + ck.verify( + flm(y, 8), + {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, + ( + z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8), + z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8), + ), + ) + ck1 = IntSetChecker() + ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2)) + ck1.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (x * 4, x * 4 + 3)) + + +def test_max_min(): + ck = IntSetChecker() + x, y = te.var("x"), te.var("y") + ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11)) + ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9)) + ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y))) + ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y))) + + +def test_select(): + ck = IntSetChecker() + # x, y = te.var("x"), te.var("y") + x = te.var("x") + ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) + + +def check_region_bound(expect_region, var_dom, mode, predicate=None): + """Helper to check region bound estimation. + + Parameters + ---------- + expect_region: dict + The keys are of form (begin, end) or PrimExpr as a single point. The values are + expected estimated region or region dict on different bindings. + + var_dom: dict + Map var to iteration domain range. + + mode: str + Specify "lowerbound", "upperbound" or else use strict bound estimation. + + predicate: PrimExpr + Extra predicate, defaults to True. 
+ """ + if predicate is None: + predicate = tvm.tir.IntImm("bool", 1) + region = [] + expect = [] + for k, v in expect_region.items(): + if not isinstance(k, (tuple, list)): + k = (k, k + 1) + region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0]))) + expect.append(v) + if mode == "lowerbound": + result = tvm.arith.estimate_region_lower_bound(region=region, var_dom=var_dom, predicate=predicate) + elif mode == "upperbound": + result = tvm.arith.estimate_region_upper_bound(region=region, var_dom=var_dom, predicate=predicate) + else: + result = tvm.arith.estimate_region_strict_bound(region=region, var_dom=var_dom, predicate=predicate) + if result is None: + assert all([_ is None for _ in expect]) + return + assert len(result) == len(expect) + for intset, expect_desc in zip(result, expect): + if isinstance(expect_desc, dict): + # check range on different free var bindings + for binding in expect_desc: + analyzer = Analyzer() + for k, v in binding: + analyzer.bind(k, v) + expect_begin, expect_end = expect_desc[binding] + result_begin = analyzer.simplify(intset.min_value, 3) + result_end = analyzer.simplify(intset.max_value + 1, 3) + assert analyzer.can_prove_equal(result_begin - expect_begin, 0), f"{result_begin} vs {expect_begin}" + assert analyzer.can_prove_equal(result_end - expect_end, 0), f"{result_end} vs {expect_end}" + else: + # check range + expect_begin, expect_end = expect_desc + analyzer = Analyzer() + assert analyzer.can_prove_equal(intset.min_value - expect_begin, 0), f"{intset.min_value} vs {expect_begin}" + assert analyzer.can_prove_equal(intset.max_value - expect_end + 1, 0), f"{intset.max_value} vs {expect_end - 1}" + + +def test_region_bound_not_independent(): + # (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available + i = tvm.tir.Var("i", "int32") + var_dom = { + i: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({(i, i + 2): None, (i + 2, i + 4): None}, var_dom, mode="lowerbound") + check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound") + + # when only a subset of access indices are affine + i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32") + var_dom = { + i: tvm.ir.Range(begin=0, end=16), + j: tvm.ir.Range(begin=0, end=16), + k: tvm.ir.Range(begin=0, end=16), + } + check_region_bound( + {i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None}, + var_dom, + predicate=j * 4 + i % 4 > 3, + mode="lowerbound", + ) + check_region_bound( + {i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)}, + var_dom, + predicate=j * 4 + i % 4 > 3, + mode="upperbound", + ) + + +def test_region_bound_stride_too_wide(): + i = tvm.tir.Var("i", "int32") + var_dom = {i: tvm.ir.Range(begin=0, end=64)} + check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound") + check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound") + + +def test_region_bound_small_stride(): + i = tvm.tir.Var("i", "int32") + var_dom = { + i: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({(i * 4, i * 4 + 8): (0, 260)}, var_dom, mode="lowerbound") + + +def test_region_lower_bound_split_predicate(): + x_o = tvm.tir.Var("xo", "int32") + x_i = tvm.tir.Var("xi", "int32") + x = x_o * 4 + x_i + var_dom = { + x_o: tvm.ir.Range(begin=0, end=16), + x_i: tvm.ir.Range(begin=0, end=4), + } + check_region_bound({(x * 4, x * 4 + 8): (0, 256)}, var_dom, predicate=x < 63, mode="lowerbound") + + check_region_bound( + {(x * 4, x * 4 + 8): (0, 256), 
(x * 3, x * 3 + 5): (0, 191)}, + var_dom, + predicate=x < 63, + mode="upperbound", + ) + + +def test_region_lower_bound_multiple_variables(): + div = tvm.tir.floordiv + mod = tvm.tir.floormod + x = tvm.tir.Var("x", "int32") + wid = tvm.tir.Var("wid", "int32") + i = div(x, 16) + j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4 + k = wid % 32 + var_dom = { + x: tvm.ir.Range(begin=0, end=32), + wid: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({i: (0, 2), j: (0, 32), k: (0, 32)}, var_dom, mode="lowerbound") + + +def test_region_lower_bound_negative_scale(): + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + var_dom = { + i: tvm.ir.Range(begin=0, end=4), + j: tvm.ir.Range(begin=0, end=4), + } + check_region_bound({(1 - i, 5 - i): (-2, 5), (20 - j * 4, 36 - j * 4): (8, 36)}, var_dom, mode="lowerbound") + + +def test_region_lower_bound_for_non_perfect_tile(): + h1 = tvm.tir.Var("h1", "int32") + h2 = tvm.tir.Var("h2", "int32") + h3 = tvm.tir.Var("h3", "int32") + + # non-uniform tiling, single inner variable + var_dom = { + h2: tvm.ir.Range(begin=0, end=10), + } + check_region_bound( + { + h3 * 8 + h2: { + (): ( + tvm.tir.max(h3 * 8, 1), + tvm.tir.min(0, h3 * 8 - 214) + 224, + ), + ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) + ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) + ((h3, 27),): (h3 * 8, 224), # h3 > 26: region is [h3 * 8, 224) + } + }, + var_dom, + predicate=tvm.tir.all(h3 * 8 + h2 >= 1, h3 * 8 + h2 < 224), + mode="lowerbound", + ) + + # non-uniform tiling, two inner variables + var_dom = { + h1: tvm.ir.Range(begin=0, end=5), + h2: tvm.ir.Range(begin=0, end=2), + } + check_region_bound( + { + h3 * 8 + h2 * 5 + h1: { + (): ( + tvm.tir.max(h3 * 8, 1), + tvm.tir.min(0, h3 * 8 - 214) + 224, + ), + ((h3, 0),): (1, 10), + ((h3, 10),): (h3 * 8, h3 * 8 + 10), + ((h3, 27),): (h3 * 8, 224), + } + }, + var_dom, + predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h2 * 5 + h1 < 224), + mode="lowerbound", + ) + + # lowerbound should fail on incompatible predicates + check_region_bound( + {h3 * 8 + h2 * 5 + h1: None}, + var_dom, + predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224), + mode="lowerbound", + ) + check_region_bound( + {h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)}, + var_dom, + predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224), + mode="upperbound", + ) + + +def test_region_lower_bound_unfusable(): + var_dom = { + tvm.tir.Var("i", "int32"): tvm.ir.Range(8), + tvm.tir.Var("j", "int32"): tvm.ir.Range(4), + } + i, j = var_dom + check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound") + + +def test_union_lower_bound(): + neg_inf = tvm.arith.int_set.neg_inf() + pos_inf = tvm.arith.int_set.pos_inf() + set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0) + set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) + assert result.min_value.same_as(neg_inf) + assert result.max_value.same_as(pos_inf) + set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2]) + assert result.min_value.same_as(neg_inf) + assert result.max_value.same_as(pos_inf) + + +def test_modular_set(): + ck = IntSetChecker() + x = tvm.te.var("x", dtype="int32") + y = tvm.te.var("y", dtype="int32") + expr = (x * 2048 + y * 16) % 7168 + ck.verify(expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 
7152)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tilelang/original/testing/python/arith/test_arith_iter_affine_map.py b/tilelang/original/testing/python/arith/test_arith_iter_affine_map.py new file mode 100644 index 0000000000000000000000000000000000000000..7a666f87d7ef7a822e99af27a42ad6825b8dd8a9 --- /dev/null +++ b/tilelang/original/testing/python/arith/test_arith_iter_affine_map.py @@ -0,0 +1,1292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tilelang import tvm +import tilelang.testing +from tvm.tir import floordiv, floormod +from tvm.script import tir as T + + +def ifuse(inputs, pred_extent=None): + """Fuse iterators""" + value, extent = 0, 1 + for i, ext in inputs: + value = value * ext + i + extent = extent * ext + return value, extent if pred_extent is None else pred_extent + + +def isplit(axis, factor): + """Split iterators""" + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + return [ + (fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)), + (flm(axis[0], factor), factor), + ] + + +def var_dom(iters): + """Get domains of iterators""" + return {var: tvm.ir.Range(0, ext) for var, ext in iters} + + +def convert_iter_expr(expr): + return tvm.arith.normalize_iter_map_to_expr(expr) + + +def assert_iter_sum_pattern(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True): + keys = list(expect_dict.keys()) + res = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + indices = res.indices + assert len(indices) == len(keys), res.errors + for i, input_iter in enumerate(keys): + spec = expect_dict[input_iter] + ( + extent, + base, + ) = spec[0:2] + scale = spec[2] if len(spec) > 2 else 1 + expect_iter = spec[3] if len(spec) > 3 else None + sum_expr = indices[i] + assert isinstance(sum_expr, tvm.arith.IterSumExpr) + if extent == 1: + assert len(sum_expr.args) == 0 + else: + assert len(sum_expr.args) == 1 + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.base, base) + if expect_iter is not None: + if not isinstance(expect_iter, tvm.arith.IterMapExpr): + sum_expr = convert_iter_expr(sum_expr) + tvm.ir.assert_structural_equal(sum_expr, expect_iter) + + +def assert_iter_map_simplify(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True): + keys = list(expect_dict.keys()) + _imap = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + res = tvm.arith.iter_map_simplify( + keys, + dom_map, + 
predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + for i, input_expr in enumerate(keys): + expected_expr = expect_dict[input_expr] + tvm.ir.assert_structural_equal(res[i], expected_expr) + + +def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): + res = tvm.arith.detect_iter_map(list(iters), dom_map, predicate=predicate, check_level=check_level).indices + assert len(res) == 0 + + +def test_trivial(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) + + assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) + assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map) + + # not independent + assert_iter_sum_failure([x, x, 3], dom_map) + + assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True) + assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False) + assert_iter_sum_failure([x, z], dom_map, check_level="bijective") + + +def test_fuse(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + c = tvm.tir.SizeVar("c", "int32") + c0 = tvm.tir.SizeVar("c0", "int32") + + assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) + + assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)])) + + # fuse with symbolic factor + assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)])) + + # duplication + assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) + assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)])) + + # factor mismatch + assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)])) + + # simple stride pattern + assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)])) + + # simple stride pattern with symbolic + assert_iter_sum_pattern({x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)])) + + +def test_split(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + c0 = tvm.tir.SizeVar("c0", "int32") + c1 = tvm.tir.SizeVar("c1", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) + + assert_iter_sum_pattern({fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)])) + + # simple symbolic bound + # TODO(tvm-team) improve symbolic divisible check to enable + # more complicated symbolic bound + assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)])) + + assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)])) + + assert_iter_sum_pattern( + { + fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2), + }, + var_dom([(x, 8)]), + ) + + assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) + + # domain of x is undefined + assert_iter_sum_pattern({fld(flm(x, 49) + y, 49): (1, fld(flm(x, 49) + y, 49))}, var_dom([(y, 1)])) + + +def test_compound(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) + z = ifuse([yo, xo, yi]) + + # reconstruct the pattern manually + mx = tvm.arith.IterMark(x, 10) + my = tvm.arith.IterMark(y, 9) + xoscale = 3 + yoscale = 6 + yiscale = 1 + mxo = 
tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) + myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) + myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) + mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) + sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) + assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) + + +def test_compound_floormod_two_regression(): + x = tvm.tir.Var("x", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + # regression + # extent of 2 of negative scale cannot be normalized + assert_iter_sum_failure( + [fld(x, 2) * 2 - flm(x, 2) + 1], + dom_map=var_dom([(x, 8)]), + ) + + +def test_predicate(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + + # available constraints + # upper bound only + assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128) + + assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127) + + # lower bound only + assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5) + + assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6) + + # lower bound + upper bound + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), + ) + + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), + ) + + assert_iter_sum_pattern( + {x * 64 + y * 4 + z: (16, 16)}, + var_dom([(x, 16), (y, 16), (z, 4)]), + predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, x * 16 + y >= 4), + ) + + # constraints on one fused iter + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + k = tvm.tir.Var("k", "int32") + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (88, 1)}, + var_dom([(i, 11), (j, 5), (k, 2)]), + predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9), + ) + + # constraints on single var + assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10) + + # iterations are subparts of constraint, invalid case 1 + assert_iter_sum_failure( + [i, j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), + ) + + # iterations are subparts of constraint, invalid case 2 + assert_iter_sum_failure( + [i * 128 + j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + predicate=i * 16384 + j * 128 + k < 100, + ) + + # irrelevant predicate + assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24) + + # constraint on nested fused iters + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (22, 3)}, + var_dom([(i, 11), (j, 5), (k, 2)]), + predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9, i * 8 + j * 2 + k >= 3, i * 8 + j * 2 + k < 25), + ) + + # duplicate constraint on one fused iter + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (66, 2)}, + var_dom([(i, 11), (j, 5), (k, 2)]), + predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k >= 2, j * 2 + k < 8, j * 2 + k < 9), + ) + + # duplicate constraint on nested fused iters + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (15, 3)}, + var_dom([(i, 11), (j, 5), (k, 2)]), + predicate=tvm.tir.all( + j * 2 + k >= 1, + j * 2 + k >= 2, + j * 2 + k < 8, + j * 2 + k < 9, + i * 6 + j * 2 + k >= 3, + i * 6 + j * 2 + k < 25, + i * 6 + j * 2 + k >= 1, + i * 6 + j * 2 + k < 18, + ), + ) + + # 
constraint on non-disjoint fused iters should fail + assert_iter_sum_failure( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + predicate=tvm.tir.all(j * 2 + k >= 2, i * 4 + j >= 0), + ) + + # constraints with different lower bound + assert_iter_sum_pattern( + { + (i * 16 + j) // 23 * 8 + (i * 16 + j) % 23 - 15: ( + 64, + 0, + 1, + (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)), + ) + }, + var_dom([(i, 12), (j, 16)]), + predicate=tvm.tir.And( + tvm.tir.And(i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23)), + tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23), + ), + ) + + # constraint on many disjoint fused iters, case 1 + # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) + # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) + # i1 * 60 in [60, 240), extent=180 (= scale of i0) + i0 = tvm.tir.Var("i0", "int32") + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + i5 = tvm.tir.Var("i5", "int32") + assert_iter_sum_pattern( + {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, + var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), + predicate=tvm.tir.all(i1 >= 1, i2 * 2 + i3 >= 2, i4 * 6 + i5 >= 3), + ) + + # constraint on many disjoint fused iters, case 2 + assert_iter_sum_pattern( + {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, + var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), + predicate=tvm.tir.all(i1 * 5 + i2 >= 3, i1 * 5 + i2 < 8, i3 * 4 + i4 >= 1, i3 * 4 + i4 < 10), + ) + + # constraint on split iters + assert_iter_sum_pattern( + {i % 16: (7, 3), i // 16: (8, 4)}, + var_dom([(i, 1024)]), + predicate=tvm.tir.all(i % 16 >= 3, i % 16 < 10, i // 16 >= 4, i // 16 < 12), + check_level="bijective", + ) + + # constraint on split iters, nested case 1 + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (7, 3)}, + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all((i * 32 + j) % 16 >= 3, (i * 32 + j) % 16 < 10), + ) + + # constraint on split iters, nested case 2 + assert_iter_sum_failure( + [ + (i * 32 + j) % 16, + ], + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32), + check_level="bijective", + ) + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (16, 0)}, + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32), + ) + assert_iter_sum_pattern( + {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 64), + ) + + # non-standard form of predicate + assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y) + + # duplicate constraint + assert_iter_sum_pattern( + {x * 10 + y: (64, 0)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), + ) + + # useless constraint + assert_iter_sum_pattern({x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140) + + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + assert_iter_sum_pattern( + {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( + tvm.tir.all( + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, + ) + ), + ) + + # wrong constraint + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 
7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( + tvm.tir.all( + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 7, + ) + ), + ) + + # incompatible constraint + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( + tvm.tir.all( + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, + i1 * 4 + i3 < 20, + ) + ), + ) + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( + tvm.tir.all( + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i1 * 4 + i3 < 20, + ) + ), + ) + + # zero iter + xo = tvm.tir.Var("xo", "int32") + xi = tvm.tir.Var("xi", "int32") + y = tvm.tir.Var("y", "int32") + assert_iter_sum_pattern( + {xo * 129 + xi: (128, 0), y: (128, 0)}, + var_dom([(xo, 1), (xi, 129), (y, 128)]), + predicate=xo * 129 + xi < 128, + ) + + # strided iteration predicate + assert_iter_sum_pattern( + {xo * 16 + xi * 4: (10, 0, 4)}, + var_dom([(xo, 3), (xi, 4)]), + predicate=xo * 4 + xi < 10, + ) + + +def convert_division(divisions): + if divisions is None or len(divisions) == 0: + return [] + res = [] + for division in divisions[:-1]: + res.append( + [ + tvm.arith.normalize_iter_map_to_expr(division[0].source), + tvm.arith.normalize_iter_map_to_expr(division[1].source), + ] + ) + res.append([divisions[-1][0].extent, divisions[-1][1].extent]) + return res + + +def create_iter(name, extent): + return tvm.tir.Var(name, "int32"), extent + + +def test_subspace_division(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + c = tvm.tir.SizeVar("c", "int32") + + # simple 1.1 + res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x]) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) + tvm.ir.assert_structural_equal(res[0][1], x + c) + + # simple 1.2 + res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) + tvm.ir.assert_structural_equal(res[0][1], x + c) + tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) + tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) + + # compound 1 + i0 = create_iter("i0", 4) + j0 = create_iter("j0", 8) + i3 = create_iter("i3", 2) + + i1, i2 = isplit(j0, 4) + k0 = ifuse([i0, i1]) + k1 = ifuse([i2, i3]) + + # compound 1.1 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]]) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][1], i3[0]) + + # assert_iter_sum_pattern + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices + assert len(res2) == 2 + + # compound 1.2 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]]) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], i0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][0], 
T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices + assert len(res2) == 2 + + # compound 1.3 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]]) + res = convert_division(res) + assert len(res) == 0 + + # compound 1.4 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][1], i3[0]) + tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) + tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices + assert len(res2) == 2 + + # compound 1.5 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], i0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) + tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) + tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices + assert len(res2) == 2 + + # compound 1.6 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7)) + res = convert_division(res) + assert len(res) == 0 + + # compound 2 + j0 = create_iter("j0", 4) + l0 = create_iter("l0", 2) + l1 = create_iter("l1", 6) + j3 = create_iter("j3", 3) + + k0 = ifuse([l0, l1]) + i1, j2 = isplit(k0, 3) + j1, i1 = isplit(i1, 2) + i0 = ifuse([j0, j1]) + i2 = ifuse([j2, j3]) + + # compound 2.1 + res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]]) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices + assert len(res2) == 3 + + # compound 2.2 + res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]]) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], j0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 
6)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])).indices + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices + assert len(res2) == 3 + + # compound 2.3 + res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]]) + res = convert_division(res) + assert len(res) == 0 + + # compound 2.4 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], + var_dom([j0, l0, l1, j3]), + [l1[0], j3[0]], + tvm.tir.all(i0[0] < 7, i2[0] < 8), + ) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) + tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) + tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices + assert len(res2) == 3 + + # compound 2.5 + res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8) + res = convert_division(res) + assert len(res) == 0 + + +def test_subspace_divide_trivial_iters(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + # z = tvm.tir.Var("z", "int32") + + # trivial 1.1 + res = tvm.arith.subspace_divide([x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], x) + tvm.ir.assert_structural_equal(res[0][1], y) + + # trivial 1.2 + res = tvm.arith.subspace_divide( + [x, y], + var_dom([(x, 1), (y, 1)]), + [y], + simplify_trivial_iterators=False, + ) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], x) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][1], y) + + +def test_complex(): + n0 = create_iter("n0", 2) + n1 = create_iter("n1", 4) + + m0 = ifuse([n0, n1], 6) + m1 = create_iter("m1", 3) + + l0 = create_iter("l0", 4) + l1 = create_iter("l1", 8) + l2 = ifuse([m0, m1], 16) + l3 = create_iter("l3", 32) + + k0, k4 = isplit(l0, 2) + k1, k5 = isplit(l1, 2) + k2, k6 = isplit(l2, 4) + k3, k7 = isplit(l3, 4) + + j0 = ifuse([k0, k1], 7) + j1 = ifuse([k2, k3]) + j2 = ifuse([k4, k5]) + j3 = ifuse([k6, k7], 15) + + i0 = ifuse([j0, j1], 200) + i1 = ifuse([j2, j3], 50) + + n0_mark = tvm.arith.IterMark(n0[0], n0[1]) + n1_mark = tvm.arith.IterMark(n1[0], n1[1]) + l0_mark = tvm.arith.IterMark(l0[0], l0[1]) + l1_mark = tvm.arith.IterMark(l1[0], l1[1]) + m1_mark = tvm.arith.IterMark(m1[0], m1[1]) + l3_mark = tvm.arith.IterMark(l3[0], l3[1]) + + m0_expr = tvm.arith.IterSumExpr( + [ + 
tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4), + tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1), + ], + 0, + ) + m0_mark = tvm.arith.IterMark(m0_expr, 6) + l2_expr = tvm.arith.IterSumExpr( + [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)], + 0, + ) + l2_mark = tvm.arith.IterMark(l2_expr, 16) + k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4) + k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1) + k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8) + k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1) + k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30) + k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15) + k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4) + k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1) + + j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0) + j0_mark = tvm.arith.IterMark(j0_expr, 7) + i0_expr = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0) + + j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0) + j3_mark = tvm.arith.IterMark(j3_expr, 15) + i1_expr = tvm.arith.IterSumExpr([k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0) + + i0_mark = tvm.arith.IterMark(i0_expr, i0[1]) + i1_mark = tvm.arith.IterMark(i1_expr, i1[1]) + + i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) + i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) + + assert_iter_sum_pattern( + {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, + var_dom([l0, l1, n0, n1, m1, l3]), + predicate=tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), + ) + + # wrong constraint + assert_iter_sum_failure( + [i0[0], i1[0]], + var_dom([l0, l1, n0, n1, m1, l3]), + tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), + ) + + # subspace_division + res = tvm.arith.subspace_divide( + [i0[0], i1[0]], + var_dom([l0, l1, n0, n1, m1, l3]), + [n0[0], n1[0], m1[0], l3[0]], + tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), + ) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2)) + tvm.ir.assert_structural_equal(res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4)) + tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2))) + tvm.ir.assert_structural_equal(res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4))) + tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7) + tvm.ir.assert_structural_equal( + res[2][1], + tvm.tir.all( + n0[0] * 4 + n1[0] < 6, + (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16, + floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15, + ), + ) + + assert_iter_sum_pattern({res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1]) + assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1])) + + +def test_normalize_iter_map_to_expr(): + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) + z = ifuse([yo, xo, yi]) + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)])) + + tvm.ir.assert_structural_equal( + tvm.arith.normalize_iter_map_to_expr(res.indices[0]), + fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3), + ) + 
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5)) + + # iter mark wrap a complex expr + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1) + + +def test_inverse_affine_iter_map(): + analyzer = tvm.arith.Analyzer() + l0 = create_iter("l0", 64) + l1 = create_iter("l1", 64) + l2 = create_iter("l2", 64) + + # simple case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l0_1_l1_1_fused = ifuse([l0_1, l1_1]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])).indices + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 2 + l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16 + l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) + assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) + + # compound case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l2_1, l2_2 = isplit(l2, 4) + l2_0, l2_1 = isplit(l2_1, 4) + + l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])).indices + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 3 + l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16 + l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 + l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2] + + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) + assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) + assert analyzer.can_prove_equal(res[l2[0]], l2_inverse) + + # diamond-shape DAG + l0_0, l0_1 = isplit(l0, 16) + l1 = ifuse([l0_1, l0_0]) + l1_0, l1_1 = isplit(l1, 8) + l2 = ifuse([l1_1, l1_0]) + + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 1 + l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8) + l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4) + + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) + + +def test_inverse_affine_map_trivial_iter(): + analyzer = tvm.arith.Analyzer() + l0 = create_iter("l0", 64) + l1 = create_iter("l1", 64) + iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + # output_0 is expected to be constant and it is not included in the inverse map + assert len(res) == 2 + assert analyzer.can_prove_equal(res[l0[0]], outputs[1]) + assert analyzer.can_prove_equal(res[l1[0]], outputs[2]) + + +def test_free_variables(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + + # illegal iter if z is within dom + assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) + + # iter is valid if z is free, even there are linear forms of z + assert_iter_sum_pattern( + {z * 19 + y * 3 + x: (9, z * 
19)}, + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + assert_iter_sum_pattern( + {z * z + y * 3 + x: (9, z * z)}, + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + + +class TestPadding: + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + positive_test_case = tvm.testing.parameter( + # left padding only, offset divisible + ({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"), + # left padding only, offset non-divisible + ({y: 176}, {fld(80 + y, 32): (6, 2, 1)}), + ({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}), + # right padding only, offset divisible + ({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}), + # right padding only, offset non-divisible + ({x: 26}, {fld(x, 15): (2, 0, 1)}), + ({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}), + # padding constants on both side + ({x: 45}, {fld(x + 71, 32): (2, 2, 1)}), + ({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}), + # padding for free iteration part + ({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}), + ({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}), + # multiple split with same mark offset, could + # be surjective on missing (padded // LCM) + ( + {x: 240}, + { + flm(x + 10, 3): (3, 0), + flm(fld(x + 10, 3), 4): (4, 0), + flm(fld(fld(x + 10, 3), 4), 5): (5, 0), + }, + ), + # different offsets on splits + ( + {x: 240}, + { + flm(x + 1, 3): (3, 0), + flm(fld(x + 10, 3) + 2, 4): (4, 0), + flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0), + }, + ), + ) + + negative_test_case = tvm.testing.parameter( + # left padding only, offset non-divisible + ({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}), + ({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}), + # right padding only, offset divisible + ({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}), + # multiple split with same mark offset, could + # be surjective on missing (padded // LCM) + ( + {x: 240}, + { + flm(x + 10, 3), + flm(fld(x + 10, 3), 4), + flm(fld(fld(x + 10, 3), 4), 5), + fld(fld(fld(x + 10, 3), 4), 5), + }, + ), + # original extent is smaller than the divident + # it is not surjective wrt to the region [0, 16) + ({x: 3}, {flm(x, 16)}), + # (x % c1) // c2 is not proved as surjective if c1 % c2 != 0 + ({x: 255}, {fld(flm(x, 255), 16)}), + ) + + def test_padding(self, positive_test_case): + iter_extent, mapped_iterators, *args = positive_test_case + check_level = args[0] if args else "surjective" + dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()} + assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level) + + def test_padding_error(self, negative_test_case): + iter_extent, mapped_iterators, *args = negative_test_case + check_level = args[0] if args else "surjective" + dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()} + assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level) + + +def test_overlapped_fuse(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + a = tvm.tir.Var("x", "int32") + b = tvm.tir.Var("y", "int32") + + # non-bijective fuse of two + assert_iter_sum_pattern( + { + x * 7 + y: (22, 0, 1), + }, + var_dom([(x, 3), (y, 8)]), + check_level="surjective", + ) + assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective") + + # non-bijective fuse of three + assert_iter_sum_pattern( + { + x * 18 + y * 7 + z: (40, 0, 1), + }, 
+ var_dom([(x, 2), (y, 3), (z, 8)]), + check_level="surjective", + ) + assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective") + + # negative scale fusion is not allowed + assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective") + assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective") + + # with predicate + assert_iter_sum_pattern( + { + a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1), + }, + var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]), + predicate=tvm.tir.all(z < 4, x * 6 + y > 1, x * 6 + y < 10), + check_level="surjective", + ) + + # stride=1 kernel + assert_iter_sum_pattern({x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective") + + # do not allow both strided and overlapped + assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective") + + +def test_iter_map_simplify_symbolic_case(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + z = x * 32 + y + + n = tvm.tir.SizeVar("n", "int64") + + def simple_fuse0(x): + return (x // n) * n + x % n + + assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)])) + + assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) + + def fsymbolic_fuse0(x): + return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n + + assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) + + assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) + + def fsymbolic_fuse1(x): + return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n + + assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) + + assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) + + def fsymbolic_fuse2(i): + return (i // (n * n) * n + i % (n * n) // n) * n + i % n + + assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) + + +def test_iter_map_simplify_symbolic_predicate(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + + n = tvm.tir.SizeVar("n", "int64") + + def simple_fuse0(x): + return (x // n) * n + x % n + + z = x * 32 + y + assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16)) + + def fsymbolic_fuse2(i): + return (i // (n * n) * n + i % (n * n) // n) * n + i % n + + z = x * 64 + y + assert_iter_map_simplify( + {fsymbolic_fuse2(z): z}, + var_dom([(x, (n * n + 1) // 2), (y, 64)]), + predicate=(z < n * n * 32), + ) + + +def test_iter_map_simplify_symbolic_reshape(): + n = tvm.tir.Var("n", "int64") + fused = tvm.tir.Var("fused", "int64") + + ax0 = (fused // 4096) // n + ax1 = (fused // 4096) % n + ax2 = fused % 4096 + + rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096 + + assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)])) + + +def test_iter_map_simplify_unit_loop_order(): + """Test itermap simplify""" + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + z = tvm.tir.Var("z", "int64") + + # trivial iterators can be found at any when comparing via scale + # ensure order unchange + assert_iter_map_simplify({x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False) + + # Even with simplification, it should follow the original order + assert_iter_map_simplify( + {x + y + (z // 4) * 4 + z % 4: z + x + y}, + var_dom([(x, 
1), (y, 1), (z, 32)]), + simplify_trivial_iterators=False, + ) + + assert_iter_map_simplify( + {y + 64 - x % 2 * 64: y + 64 - x % 2 * 64}, + var_dom([(x, 6), (y, 64)]), + simplify_trivial_iterators=False, + ) + + # When we have iterators that have same scale but one of them come + # with unit extent, we should prioritize unit extent + assert_iter_map_simplify( + {x // 128 + y + z: y + z}, + var_dom([(x, 128), (y, 128), (z, 1)]), + simplify_trivial_iterators=False, + ) + + +def assert_normalize_to_iter_sum(index, input_iters, args, base): + """Assert the result of arith.normalize_to_iter_sum is correct + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index to be normalized + input_iters : Mapping[Var, Range] + The input iterators + args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]] + The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be + either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the + iterator normalized to PrimExpr and the second element is the scale. + base : tvm.tir.PrimExpr + The expected base + """ + res = tvm.arith.normalize_to_iter_sum(index, input_iters) + + assert isinstance(res, tvm.arith.IterSumExpr) + assert len(res.args) == len(args) + for split, item in zip(res.args, args): + if isinstance(item, tvm.arith.IterSplitExpr): + tvm.ir.assert_structural_equal(split, item) + continue + tvm.testing.assert_prim_expr_equal(split.scale, item[1]) + tvm.testing.assert_prim_expr_equal(tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]) + tvm.testing.assert_prim_expr_equal(res.base, base) + + +def test_normalize_to_iter_sum(): + x = tvm.tir.Var("x", "int64") + y = tvm.tir.Var("y", "int64") + z = tvm.tir.Var("z", "int64") + a = tvm.tir.Var("a", "int64") + n = tvm.tir.Var("n", "int64") + # flm = tvm.tir.floormod + + assert_normalize_to_iter_sum( + z + ((y + x * 4 + 2) * n) + 3, + var_dom([(x, 9), (y, 4), (z, 3)]), + [(x, n * 4), (y, n), (z, 1)], + 2 * n + 3, + ) + + # max cannot detected so it goes into base + assert_normalize_to_iter_sum( + tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3, + var_dom([(x, 9), (y, 4), (z, 3)]), + [(x, n * 4), (y, n)], + tvm.tir.max(z, a) + 2 * n + 3, + ) + + # order by symbolic prod + assert_normalize_to_iter_sum( + z + ((y * 4 * a + x * 4 + 2) * n) + 3, + var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), + [(y, a * n * 4), (x, n * 4), (z, 1)], + 2 * n + 3, + ) + + # order by cscale + assert_normalize_to_iter_sum( + z + 2 * y * 3 + 4 * x, + var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), + [(y, 6), (x, 4), (z, 1)], + 0, + ) + + # split pattern + assert_normalize_to_iter_sum( + z + 2 * y * 3 + 4 * (x // 2), + var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), + [(y, 6), (x // 2, 4), (z, 1)], + 0, + ) + + # non-divisible + assert_normalize_to_iter_sum( + x // 5, + var_dom([(x, 4096)]), + [ + tvm.arith.IterSplitExpr( + tvm.arith.IterMark(x, 4096), + lower_factor=tvm.tir.const(5, "int64"), + extent=tvm.tir.const(820, "int64"), + scale=tvm.tir.const(1, "int64"), + ) + ], + 0, + ) + + # iter simplify + assert_normalize_to_iter_sum( + z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4), + var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), + [(y, 6), (z, 2), (x, 1)], + 0, + ) + + +def test_detect_iter_map_with_bufferload_recursion(): + n = tvm.tir.Var("n", "int32") + m = tvm.tir.Var("m", "int32") + divisor = tvm.tir.Var("divisor", "int32") + + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + + buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen") + + 
indices = [(buffer[i] + j) // divisor] + iter_vars = { + i: tvm.ir.Range(tvm.tir.const(0, "int32"), n), + j: tvm.ir.Range(tvm.tir.const(0, "int32"), m), + } + + result = tvm.arith.detect_iter_map(indices, iter_vars) + assert len(result.indices) == 0 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/arith/test_arith_simplify.py b/tilelang/original/testing/python/arith/test_arith_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6cf6d3df74e0e6ddb15f463b96712f07e87b95 --- /dev/null +++ b/tilelang/original/testing/python/arith/test_arith_simplify.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tilelang import tvm +import tilelang.testing +from tvm import tir +import tvm.ir + + +def test_simplify_reshape_flattened_index(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + ana.bind(i0, tvm.ir.Range(0, 8)) + ana.bind(i1, tvm.ir.Range(0, 3)) + + i_flattened = i0 * 3 + i1 + tvm.ir.assert_structural_equal( + ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4), + i_flattened, + ) + + +dtype = tvm.testing.parameter( + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", +) + + +def test_can_prove_self_identity(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove(n == n) + + +def test_can_prove_self_equal_to_self(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove_equal(n, n) + + +def test_simplify_symbolic_comparison(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64") + outer = (n + 31) // 32 + ana.bind(i0, tvm.ir.Range(0, outer)) + ana.bind(i1, tvm.ir.Range(0, 32)) + PS = tvm.arith.ProofStrength + + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND) + + +def test_regression_simplify_inf_recursion(): + ana = tvm.arith.Analyzer() + cond = tir.Var("cond", "int32") + + res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype("int32") == 0 + # regression in a previous case + # try compare and int set recursive call can cause infinite loop + ana.rewrite_simplify(res) + + +def test_simplify_floor_mod_with_linear_offset(): + """ + Test that the floor_mod is 
simplified correctly when the offset is linear. + """ + ana = tvm.arith.Analyzer() + past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64") + expr1 = (past_decoder_sequence_length + 1) * 64 + divisor1 = (past_decoder_sequence_length + 1) * 32 + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0) + divisor2 = 32 * (past_decoder_sequence_length + 1) + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0) + + +def test_simplify_float_division(): + # Test for the discussion: + # https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615 + ana = tvm.arith.Analyzer() + x = tir.Var("x", "float32") + ry = x / 27 + # in old version, the division will be rewritten into x * T.float32(1 / 27) + sy = ana.rewrite_simplify(ry) + tvm.ir.assert_structural_equal(ry, sy) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/autotune/test_tilelang_autotune.py b/tilelang/original/testing/python/autotune/test_tilelang_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..53707ca3418ef848595beb9e4b04d5e3479c5387 --- /dev/null +++ b/tilelang/original/testing/python/autotune/test_tilelang_autotune.py @@ -0,0 +1,274 @@ +import itertools +import logging + +import tilelang.testing +import tilelang.language as T +from tilelang.autotuner import AutoTuner + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). + """ + return A @ B.T + + +def get_configs(M, N, K, with_roller=False): + """ + Generate a list of configuration dictionaries that will be used for tuning. + + Parameters + ---------- + with_roller : bool + Whether to enable bitblas roller to deduce search spaces + + Returns + ------- + list of dict + Each configuration dict includes various block sizes, pipeline stages, + thread numbers, and other parameters to explore during autotuning. 
+ """ + if with_roller: + from tilelang.carver.template import MatmulTemplate + from tilelang.carver.arch import CUDA + from tilelang.carver.roller.rasterization import NoRasterization + + arch = CUDA("cuda") + topk = 20 + + # Simple TIR Compute Expression + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + roller_hints = carve_template.recommend_hints(topk=topk) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = 0 + config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + for config in configs: + print(config) + else: + block_M = [64] + block_N = [64] + block_K = [32] + num_stages = [0, 1] + thread_num = [128] + enable_rasterization = [False] + + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def matmul(M, N, K, with_roller): + """ + Create an autotuned matrix multiplication kernel for matrices of shape: + - A: (M, K) + - B: (N, K) + - C: (M, N) + + Parameters + ---------- + M : int + The dimension M of the matrix multiplication. + N : int + The dimension N of the matrix multiplication. + K : int + The dimension K of the matrix multiplication. + + Returns + ------- + (best_latency, best_config, ref_latency) + best_latency : float + The best latency found among the tuned configurations. + best_config : dict + The parameter configuration that yielded best_latency. + ref_latency : float + The baseline latency of the reference program (for computing speedup). + """ + + # Decorate the kernel with autotune & jit, specifying: + # - Tuning config list + # - Profiling keys + # - Warmup and repetition counts for better measurement + # - A reference program for correctness verification + # - The "tvm" profiler backend + # - HIP as the compilation target (modify as needed for your hardware) + + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, + ): + """ + The actual kernel to compute C = A @ B^T. + + Parameters + ---------- + block_M : int + Block size in M dimension. + block_N : int + Block size in N dimension. + block_K : int + Block size in K dimension. + num_stages : int + Number of pipelined stages (for asynchronous load). + thread_num : int + Number of threads to use per block. + enable_rasteration : bool + Whether to enable rasterization (swizzling) optimization. + k_pack : int + K dimension packing factor to improve memory coalescing. + + Returns + ------- + Function + A TVM Tensor Language function (T.prim_func) that computes matmul. 
+ """ + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy( + A[by * block_M, k * block_K], + A_shared, + ) + # Load a sub-block of B from global memory into B_shared + T.copy( + B[bx * block_N, k * block_K], + B_shared, + ) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( + out_idx=[-1], + target="auto", + ) + .set_profile_args( + ref_prog=ref_program, + ) + ) + return autotuner.run(warmup=3, rep=20) + + +def test_autotune_get_configs(): + get_configs(1024, 1024, 1024, with_roller=True) + get_configs(1024, 1024, 1024, with_roller=False) + + +def test_autotune_matmul(): + matmul(1024, 1024, 1024, with_roller=True) + matmul(1024, 1024, 1024, with_roller=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/autotune/test_tilelang_autotune_with_inputs.py b/tilelang/original/testing/python/autotune/test_tilelang_autotune_with_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..4edea0b8839342e97e53eeb0c53558ab8ddd1ba6 --- /dev/null +++ b/tilelang/original/testing/python/autotune/test_tilelang_autotune_with_inputs.py @@ -0,0 +1,144 @@ +import itertools +import logging +import tilelang +import tilelang.testing +from tilelang.autotuner import set_autotune_inputs +import tilelang.language as T + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). 
+ """ + return A @ B.T + + +def get_configs(): + iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False]) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False): + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy( + A[by * block_M, k * block_K], + A_shared, + ) + # Load a sub-block of B from global memory into B_shared + T.copy( + B[bx * block_N, k * block_K], + B_shared, + ) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_autotune(M, N, K, M_value=None, N_value=None, K_value=None): + import torch + + def _resolve(dim, provided, name): + if isinstance(dim, T.Var): + if provided is None: + raise ValueError(f"Dynamic dimension {name} requires a concrete value.") + return provided + return dim + + actual_M = _resolve(M, M_value, "M") + actual_N = _resolve(N, N_value, "N") + actual_K = _resolve(K, K_value, "K") + + a = torch.randn(actual_M, actual_K, dtype=torch.float16).cuda() + b = torch.randn(actual_N, actual_K, dtype=torch.float16).cuda() + + with set_autotune_inputs([a, b]): + kernel = matmul(M, N, K) + + c = kernel(a, b) + + ref_c = ref_program(a, b) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def test_autotune_matmul(): + """ + Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem. + + This test constructs random CUDA tensors, autotunes the JIT-compiled block-level matrix-multiplication kernel, + executes it, and asserts the result matches a reference CPU implementation within tolerances. 
+ """ + run_autotune(1024, 1024, 1024) + + +def test_autotune_matmul_symbolic_m(): + run_autotune(T.symbolic("m"), 1024, 1024, M_value=1024) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py b/tilelang/original/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py new file mode 100644 index 0000000000000000000000000000000000000000..67d20b89790afc81e2b2b5a71795a408e700cdaf --- /dev/null +++ b/tilelang/original/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py @@ -0,0 +1,72 @@ +import tilelang.testing +from tilelang.carver.arch.driver.cuda_driver import ( + get_cuda_device_properties, + get_device_name, + get_shared_memory_per_block, + get_device_attribute, + get_max_dynamic_shared_size_bytes, + get_persisting_l2_cache_max_size, + get_num_sms, + get_registers_per_block, +) +import torch + + +class _cudaDeviceAttrNames: + r""" + This struct carries all properties that are of int32_t. + refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd + """ + + cudaDevAttrMaxThreadsPerBlock: int = 1 + cudaDevAttrMaxSharedMemoryPerBlock: int = 8 + cudaDevAttrMaxRegistersPerBlock: int = 12 + cudaDevAttrMultiProcessorCount: int = 16 + cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 + cudaDevAttrMaxPersistingL2CacheSize: int = 108 + + +def test_driver_get_device_properties(): + prop = get_cuda_device_properties() + assert prop is not None, "Failed to get CUDA device properties" + assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties" + + +def test_device_get_device_name(): + tl_device_name = get_device_name() + th_device_name = torch.cuda.get_device_name() + assert tl_device_name == th_device_name, "Device names do not match" + + +def test_device_get_shared_memory_per_block(): + tl_smem = get_shared_memory_per_block() + driver_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerBlock) + assert tl_smem == driver_smem, "Shared memory per block values do not match" + + +def test_device_get_persisting_l2_cache_size(): + tl_cache_size = get_persisting_l2_cache_max_size() + driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) + assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" + + +def test_device_get_num_sms(): + tl_num_sms = get_num_sms() + driver_num_sms = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMultiProcessorCount) + assert tl_num_sms == driver_num_sms, "Number of SMs do not match" + + +def test_device_get_registers_per_block(): + tl_regs_per_block = get_registers_per_block() + driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) + assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" + + +def test_device_get_max_dynamic_shared_size_bytes(): + tl_dynamic_smem = get_max_dynamic_shared_size_bytes() + driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) + assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match" + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/carver/test_tilelang_carver_generate_hints.py 
b/tilelang/original/testing/python/carver/test_tilelang_carver_generate_hints.py new file mode 100644 index 0000000000000000000000000000000000000000..ea674f7c743a0ba5e21bc4bafdd169fec2c1f85c --- /dev/null +++ b/tilelang/original/testing/python/carver/test_tilelang_carver_generate_hints.py @@ -0,0 +1,100 @@ +import tilelang.testing +from tilelang import carver +from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge +from tilelang.carver.arch import auto_infer_current_arch +from tvm import te +from tilelang.language import dtypes as T + + +def run_general_matmul_emit_configs(M, N, K, topk: int = 20): + arch = auto_infer_current_arch() + + def gemm(M, N, K): + A = te.placeholder((M, K), name="A", dtype=T.float16) + B = te.placeholder((N, K), name="B", dtype=T.float16) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C") + + return A, B, C + + arg1 = gemm(M, N, K) + args = arg1 + + func = te.create_prim_func(args) + + tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target) + print(tags) + policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0") + + hints = policy.emit_config(topk=topk) + + for hint in hints: + print(hint) + + assert len(hints) > 0, "Hints length is zero" + + prim_func_node = PrimFuncNode(tensorized_func, name="matmul_1") + output_nodes = [OutputNode(prim_func_node)] + policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags) + + hints = policy.emit_config(topk=10) + + for config in hints: + print(config) + + assert len(hints) > 0, "Hints length is zero" + + +def test_general_matmul_emit_configs(): + run_general_matmul_emit_configs(128, 128, 128) + + +def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): + arch = auto_infer_current_arch() + + def gemm(M, N, K): + A = te.placeholder((M, K), name="A", dtype=T.float16) + B = te.placeholder((N, K), name="B", dtype=T.float16) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C") + + return A, B, C + + arg1 = gemm(M, N, K) + args = arg1 + + func = te.create_prim_func(args) + + tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target) + print(tags) + + node_0 = PrimFuncNode(tensorized_func, name="matmul_0") + node_1 = PrimFuncNode(tensorized_func, name="matmul_1") + + edge = Edge(node_0, node_1, 0, 0) + node_0._out_edges.append(edge) + node_1.set_inputs(0, edge) + + output_nodes = [OutputNode(node_1)] + policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags) + + hints = policy.emit_config(topk=topk) + + for config in hints: + print(config) + + assert len(hints) > 0, "Hints length is zero" + + +def test_general_matmul_matmul_emit_configs(): + run_general_matmul_matmul_emit_configs(128, 128, 128) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/carver/test_tilelang_carver_recommend_hints.py b/tilelang/original/testing/python/carver/test_tilelang_carver_recommend_hints.py new file mode 100644 index 0000000000000000000000000000000000000000..3a060f5323d75f4727ae25aefe9a40608cca27cb --- /dev/null +++ b/tilelang/original/testing/python/carver/test_tilelang_carver_recommend_hints.py @@ -0,0 
+1,142 @@ +import tilelang.testing +from tilelang import carver +from tilelang.language import dtypes as T +from tilelang.carver.arch import auto_infer_current_arch +from typing import List + + +def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20): + arch = auto_infer_current_arch() + carve_template = carver.GeneralReductionTemplate( + structure=structure, + shape=shape, + dtype=dtype, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + hints = carve_template.recommend_hints(topk=topk) + assert len(hints) > 0, "Hints length is zero" + + +def test_general_reduction_recommend_hints(): + run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], T.float16) + run_general_reduction_recommend_hints("SS", [1024, 1024], T.float16) + run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], T.float16) + + +def run_elementwise_recommend_hints(shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20): + arch = auto_infer_current_arch() + carve_template = carver.ElementwiseTemplate( + shape=shape, + dtype=dtype, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + hints = carve_template.recommend_hints(topk=topk) + assert len(hints) > 0, "Hints length is not topk" + + +def test_elementwise_recommend_hints(): + run_elementwise_recommend_hints([1024, 1024], T.float16) + run_elementwise_recommend_hints([1024], T.float16) + run_elementwise_recommend_hints([1024, 1024, 1024], T.float16) + + +def run_matmul_recommend_hints( + M: int = 1024, + N: int = 1024, + K: int = 1024, + in_dtype: T.dtype = T.float16, + out_dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float16, +): + arch = auto_infer_current_arch() + carve_template = carver.MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + hints = carve_template.recommend_hints(topk=20) + assert len(hints) > 0, "Hints length is not 20" + + +def test_matmul_recommend_hints(): + run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float16, T.float16) + run_matmul_recommend_hints(1024, 1024, 1024, T.int8, T.int32, T.int32) + run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float32, T.float16) + + +def run_gemv_recommend_hints( + N: int = 1024, K: int = 1024, in_dtype: T.dtype = T.float16, out_dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float16 +): + arch = auto_infer_current_arch() + carve_template = carver.GEMVTemplate( + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + hints = carve_template.recommend_hints(topk=20) + assert len(hints) > 0, "Hints length is not 20" + + +def test_gemv_recommend_hints(): + run_gemv_recommend_hints(1024, 1024, T.float16, T.float16, T.float16) + run_gemv_recommend_hints(1024, 1024, T.int8, T.int32, T.int32) + run_gemv_recommend_hints(1024, 1024, T.float16, T.float32, T.float16) + + +def run_fmha_recommend_hints( + batch_size: int = 4, + num_heads: int = 32, + seq_length: int = 512, + seq_kv_length: int = 512, + head_dim: int = 128, + in_dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float16, + out_dtype: T.dtype = T.float16, +): + arch = 
auto_infer_current_arch() + carve_template = carver.FlashAttentionTemplate( + batch_size=batch_size, + num_heads=num_heads, + seq_length=seq_length, + seq_kv_length=seq_kv_length, + head_dim=head_dim, + in_dtype=in_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + hints = carve_template.recommend_hints(topk=20) + for hint in hints: + print(hint) + assert len(hints) > 0, "Hints length should be greater than 0" + + +def test_fmha_recommend_hints(): + run_fmha_recommend_hints(4, 32, 512, 512, 128, T.float16, T.float16, T.float16) + run_fmha_recommend_hints(4, 32, 512, 512, 128, T.int8, T.int32, T.int32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/components/test_cuda_restrict_codegen.py b/tilelang/original/testing/python/components/test_cuda_restrict_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..bff8b3b192be844e05a6c2f9a905bd55c251561b --- /dev/null +++ b/tilelang/original/testing/python/components/test_cuda_restrict_codegen.py @@ -0,0 +1,48 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +def _get_sig_line(code: str) -> str: + # Find the kernel signature line in generated CUDA code + for line in code.splitlines(): + line = line.strip() + if line.startswith('extern "C" __global__ void'): + return line + raise AssertionError("Kernel signature not found in generated code") + + +@tilelang.testing.requires_cuda +def test_cuda_restrict_default_has_restrict(): + N = 128 + + @T.prim_func + def kernel(x: T.Tensor((N,), T.float32), y: T.Tensor((N,), T.float32)): + with T.Kernel(N, threads=32) as pid: + y[pid] = x[pid] + 1.0 + + artifact = tilelang.lower(kernel, target="cuda") + sig = _get_sig_line(artifact.kernel_source) + # By default, kNoAlias is set and both pointers are restrict-qualified + assert "__restrict__" in sig + + +@tilelang.testing.requires_cuda +def test_cuda_restrict_annotation_removes_restrict(): + N = 128 + + @T.prim_func + def kernel_body_annot(x: T.Tensor((N,), T.float32), y: T.Tensor((N,), T.float32)): + # Explicitly mark buffers that may alias as non-restrict + with T.Kernel(N, threads=32) as pid: + T.annotate_restrict_buffers(x, y) + y[pid] = x[pid] + 1.0 + + art1 = tilelang.lower(kernel_body_annot, target="cuda") + sig1 = _get_sig_line(art1.kernel_source) + # No parameter should be emitted with __restrict__ + assert "__restrict__" not in sig1 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/components/test_storage_rewrite_detect_inplace.py b/tilelang/original/testing/python/components/test_storage_rewrite_detect_inplace.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4f4e5f3df8ccf79d973eb6bc62af757a79b609 --- /dev/null +++ b/tilelang/original/testing/python/components/test_storage_rewrite_detect_inplace.py @@ -0,0 +1,62 @@ +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit +def _compile_kernel_without_inplace(): + num_tokens = T.symbolic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]): + with T.Kernel(num_tokens, threads=32) as pid: + read = T.alloc_var(T.int) + read = x[pid] + + write = T.alloc_var(T.int) + write = read * 2 + x[pid] = write + + return buggy_kernel + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True, + 
}, +) +def _compile_kernel_with_inplace(): + num_tokens = T.symbolic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]): + with T.Kernel(num_tokens, threads=32) as pid: + read = T.alloc_var(T.int) + read = x[pid] + + write = T.alloc_var(T.int) + write = read * 2 + x[pid] = write + + return buggy_kernel + + +def _get_device_kernel_script(detect_inplace: bool) -> str: + if detect_inplace: + kernel = _compile_kernel_with_inplace() + else: + kernel = _compile_kernel_without_inplace() + source = kernel.get_kernel_source() + return source + + +def test_storage_rewrite_detect_inplace_toggle(): + script_off = _get_device_kernel_script(detect_inplace=False) + script_on = _get_device_kernel_script(detect_inplace=True) + + assert script_off.count("read = (read * 2);") == 0 + assert script_on.count("read = (read * 2);") > 0 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/components/test_tilelang_env.py b/tilelang/original/testing/python/components/test_tilelang_env.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc7679437ae0e18834883e205ccb81203a994bc --- /dev/null +++ b/tilelang/original/testing/python/components/test_tilelang_env.py @@ -0,0 +1,17 @@ +import tilelang +import os + + +def test_env_var(): + # test default value + assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1" + # test forced value + os.environ["TILELANG_PRINT_ON_COMPILATION"] = "0" + assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0" + # test forced value with class method + tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1" + assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1" + + +if __name__ == "__main__": + test_env_var() diff --git a/tilelang/original/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py b/tilelang/original/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py new file mode 100644 index 0000000000000000000000000000000000000000..d599e581ac24c30afe8831a7ca2b114991626228 --- /dev/null +++ b/tilelang/original/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py @@ -0,0 +1,140 @@ +import tilelang.testing +from tilelang import language as T +import torch + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + 
trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, + disable_warp_specialized=False, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized, + }, + ) + profiler = kernel.get_profiler() + + def ref_program(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + disable_warp_specialized=False, + ) + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + disable_warp_specialized=True, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/cpu/test_tilelang_cpu_gemm.py b/tilelang/original/testing/python/cpu/test_tilelang_cpu_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..55646622e3d22057416278ae5215f2626ccdd287 --- /dev/null +++ b/tilelang/original/testing/python/cpu/test_tilelang_cpu_gemm.py @@ -0,0 +1,117 @@ +import tilelang +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +import torch + + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + num_stages = 0 + + @T.prim_func + def matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): + A_local = T.alloc_local((block_M, block_K), dtype) + B_local = T.alloc_local((block_K, block_N), dtype) + C_local = T.alloc_local((block_M, block_N), accum_dtype) + + T.clear(C_local) + + # Apply layout optimizations or define your own layout + # (Optional). 
+ # T.annotate_layout( + # { + # A_local: make_swizzle_layout(A_local), + # B_local: make_swizzle_layout(B_local), + # } + # ) + + for ko in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_local) + + # Or Copy with Parallel + for k, j in T.Parallel(block_K, block_N): + B_local[k, j] = B[ko * block_K + k, by * block_N + j] + + for i, j, k in T.grid(block_M, block_N, block_K): + C_local[i, j] += A_local[i, k] * B_local[k, j] + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul + + +def assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32): + func = matmul(M, N, K, block_M, block_N, block_K) + + with tvm.target.Target("c"): + artifact = tilelang.lower(func) + + code = artifact.kernel_source + + assert code is not None, "Code generation failed" + + +def test_matmul_codegen(): + assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32) + + +def test_matmul_compile(): + def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + # a simple kernel just for jit test + @T.prim_func + def matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): + A_local = T.alloc_local((block_M, block_K), dtype) + B_local = T.alloc_local((block_K, block_N), dtype) + C_local = T.alloc_local((block_M, block_N), accum_dtype) + + for p in T.serial(block_M): + for w in T.serial(block_N): + C_local[p, w] = 0 + for ko in T.serial(K // block_K): + for i in T.serial(block_M): + for k in T.serial(block_K): + A_local[i, k] = A[by * block_M + i, ko * block_K + k] + + for k in T.serial(block_K): + for j in T.serial(block_N): + B_local[k, j] = B[ko * block_K + k, bx * block_N + j] + + for i in T.serial(block_M): + for j in T.serial(block_N): + for k in T.serial(block_K): + C_local[i, j] += A_local[i, k] * B_local[k, j] + + for i in T.serial(block_M): + for j in T.serial(block_N): + C[by * block_M + i, bx * block_N + j] = C_local[i, j] + + return matmul + + M, N, K = 1024, 512, 512 + block_M, block_N, block_K = M // 4, N // 4, K // 4 + cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K) + with tvm.target.Target("c"): + complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes") + + in_dtype = T.float16 + A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)) + B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)) + + C = complied_fun(A, B) + C_torch = torch.matmul(A, B) + + tilelang.testing.torch_assert_close(C, C_torch, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/dcu/__pycache__/test_tilelang_gemm_mmac_intrinsic.cpython-310-pytest-9.0.2.pyc b/tilelang/original/testing/python/dcu/__pycache__/test_tilelang_gemm_mmac_intrinsic.cpython-310-pytest-9.0.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3939ec7b500949e7298212f6fe4e29ebfdb7773c Binary files /dev/null and b/tilelang/original/testing/python/dcu/__pycache__/test_tilelang_gemm_mmac_intrinsic.cpython-310-pytest-9.0.2.pyc differ diff --git a/tilelang/original/testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py b/tilelang/original/testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..abf43d5ec17b93ff1c14233013802db66f374e6b --- /dev/null +++ 
b/tilelang/original/testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py @@ -0,0 +1,248 @@ +import torch +import tilelang.testing +from tilelang import tvm as tvm +from tvm import DataType +import tilelang.language as T + +# from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mmac_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + +tilelang.testing.set_random_seed(0) +tilelang.disable_cache() + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + a_transposed=False, + b_transposed=True, + k_pack=1, +): + assert in_dtype in [ + "float16", + "bfloat16", + "int8", + ], "Currently only float16, bfloat16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if in_dtype in {"float8_e4m3fnuz", "int8"}: + micro_size_k = 32 + + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + + chunk = 32 * k_pack + + shared_scope = "shared" + # cache_write_shared = False + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size + local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mmac_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + k_pack=k_pack, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + 
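+ # make_swizzle_layout (defined above) only swizzles when the trailing dimension spans exactly
+ # 512 bits, e.g. float16 with block_K = 32 (32 * 16 = 512) or int8 with k_pack = 2 (64 * 8 = 512);
+ # otherwise it falls back to the identity layout.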
T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=0): + # Load A into shared memory + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Load B into shared memory + if b_transposed: + T.copy(B[bx * block_N, ko * block_K], B_shared) + else: + T.copy(B[ko * block_K, bx * block_N], B_shared) + + for ki in T.serial(0, (block_K // (k_pack * micro_size_k))): + # Load A into fragment + mmac_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mmac_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_local, B_local, C_local) + + # Perform STMatrix + mmac_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + j // micro_size_y, + i // micro_size_x, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) + print(matmul) + kernel = tilelang.compile(matmul) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + if in_dtype == "int8": + A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) + else: + A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + kernel(A, B, C) + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler() + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + if a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + elif a_transposed and not b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + elif not a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + else: + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) + + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_rocm +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2) + assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2) +
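+ # The next two cases take the non-transposed-B path: B is laid out (K, N) in global memory
+ # and staged into a (block_K, block_N) shared tile instead of (block_N, block_K).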
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2) + # assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16") + # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32") + # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2) + # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False) + # assert_tl_matmul_correctness( + # 128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/debug/test_device_assert.py b/tilelang/original/testing/python/debug/test_device_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..210b8966d7b4acced3a1adc62725a3e579e14900 --- /dev/null +++ b/tilelang/original/testing/python/debug/test_device_assert.py @@ -0,0 +1,34 @@ +# type: ignore +import tilelang +import tilelang.testing +import tilelang.language as T + + +# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI +# Please run manually when you want to verify that device_assert actually traps on GPU. +def _manual_device_assert_triggered(): + @T.prim_func + def program(): + with T.Kernel(threads=128): + tid = T.get_thread_binding() + T.device_assert(tid > 0, "Assertion Trigger !") + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_device_assert_no_trigger(): + @T.prim_func + def program(): + with T.Kernel(threads=128): + tid = T.get_thread_binding() + T.device_assert(tid == tid) + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +if __name__ == "__main__": + _manual_device_assert_triggered() diff --git a/tilelang/original/testing/python/debug/test_tilelang_debug_print.py b/tilelang/original/testing/python/debug/test_tilelang_debug_print.py new file mode 100644 index 0000000000000000000000000000000000000000..3483cffc0eb7aeb99669a716d83dd1610a66f27f --- /dev/null +++ b/tilelang/original/testing/python/debug/test_tilelang_debug_print.py @@ -0,0 +1,119 @@ +# type: ignore + +import tilelang +import tilelang.testing +import tilelang.language as T + + +def debug_print_buffer(M=16, N=16, dtype=T.float16): + @T.prim_func + def program(Q: T.Tensor((M, N), dtype)): + with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): + shared_buf = T.alloc_shared([M, N], dtype) + T.print(shared_buf) + + jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_debug_print_buffer(): + debug_print_buffer(dtype=T.bool) + debug_print_buffer(dtype=T.int8) + debug_print_buffer(dtype=T.int16) + debug_print_buffer(dtype=T.int32) + debug_print_buffer(dtype=T.int64) + debug_print_buffer(dtype=T.uint8) + debug_print_buffer(dtype=T.uint16) + debug_print_buffer(dtype=T.uint32) + debug_print_buffer(dtype=T.uint64) + debug_print_buffer(dtype=T.float16) + debug_print_buffer(dtype=T.float32) + debug_print_buffer(dtype=T.float64) + debug_print_buffer(dtype=T.bfloat16) + debug_print_buffer(dtype=T.float8_e4m3fn) + debug_print_buffer(dtype=T.float8_e4m3fn) + debug_print_buffer(dtype=T.float8_e4m3fnuz) + debug_print_buffer(dtype=T.float8_e5m2) + 
debug_print_buffer(dtype=T.float8_e5m2fnuz) + + +def debug_print_buffer_conditional(M=16, N=16): + dtype = T.float16 + + @T.prim_func + def program(Q: T.Tensor((M, N), dtype)): + with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): + shared_buf = T.alloc_shared([M, N], dtype) + + if bx == 0 and by == 0 and bz == 0: + T.print(shared_buf) + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_debug_print_buffer_conditional(): + debug_print_buffer_conditional(16, 16) + + +def debug_print_value_conditional(M=16, N=16): + dtype = T.float16 + + @T.prim_func + def program(Q: T.Tensor((M, N), dtype)): + with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): + tid = T.get_thread_binding() + if tid == 0: + T.print(bx + by + bz) + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_debug_print_value_conditional(): + debug_print_value_conditional(16, 16) + + +def debug_print_register_files(M=16, N=16): + dtype = T.float16 + + @T.prim_func + def program(Q: T.Tensor((M, N), dtype)): + with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): + register_buf = T.alloc_fragment([M, N], dtype) + for i, j in T.Parallel(M, N): + T.print(register_buf[i, j]) + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_debug_print_register_files(): + debug_print_register_files(16, 16) + + +def debug_print_msg(M=16, N=16): + dtype = T.float16 + + @T.prim_func + def program(Q: T.Tensor((M, N), dtype)): + with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): + tid = T.get_thread_binding() + if tid == 0: + T.print(bx + by + bz, msg="hello world") + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + + +def test_debug_print_msg(): + debug_print_msg(16, 16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/fastmath/test_mathops_fastmath.py b/tilelang/original/testing/python/fastmath/test_mathops_fastmath.py new file mode 100644 index 0000000000000000000000000000000000000000..e181eb4dffe973f13aae267c6328d4a2a926bb05 --- /dev/null +++ b/tilelang/original/testing/python/fastmath/test_mathops_fastmath.py @@ -0,0 +1,319 @@ +import pytest +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import re + + +def get_mathop_lines(source, mathop_name): + """Extract lines containing the mathop from CUDA source for debugging""" + lines = source.split("\n") + relevant_lines = [] + for i, line in enumerate(lines): + if mathop_name in line and ("(" in line): + # Include some context + start = max(0, i - 1) + end = min(len(lines), i + 2) + relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) + relevant_lines.append("---") + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + + +def check_fastmath_usage(source, mathop_name, expect_fastmath=False): + """Check source for fastmath/non-fastmath versions""" + fastmath_pattern = rf"__({mathop_name}f?)\b" + non_fastmath_pattern = rf"(? 
<!__)\b({mathop_name}f?)\s*\(" + + fastmath_matches = re.findall(fastmath_pattern, source) + non_fastmath_matches = re.findall(non_fastmath_pattern, source) + + if len(fastmath_matches) >
0: + print(f"Fastmath calls found: {fastmath_matches}") + if len(non_fastmath_matches) > 0: + print(f"Non-fastmath calls found: {non_fastmath_matches}") + print(f"Source preview for {mathop_name}:") + print(get_mathop_lines(source, mathop_name)) + + if expect_fastmath: + assert len(fastmath_matches) > 0, "Expected fastmath calls but found none" + print(f"✓ {mathop_name} correctly uses fastmath versions") + else: + assert len(fastmath_matches) == 0, f"Found unexpected fastmath calls: {fastmath_matches}" + assert len(non_fastmath_matches) > 0, f"No {mathop_name} calls found" + print(f"✓ {mathop_name} correctly uses non-fastmath versions") + + +def check_non_fastmath_usage(source, mathop_name): + """Check that source uses non-fastmath versions (no __ prefix)""" + check_fastmath_usage(source, mathop_name, expect_fastmath=False) + + +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test single-argument mathops. + T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) + """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) + + # Test with FAST_MATH disabled + kernel_no_fastmath = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + source_no_fastmath = kernel_no_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} ===") + print("FAST_MATH=False:") + + # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf) + check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False) + + print(f"✓ {mathop_name} compilation and execution test passed") + + +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test two-argument mathops to ensure they generate non-fastmath CUDA code. 
+ """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) + + # Test with FAST_MATH disabled + kernel_no_fastmath = tilelang.compile( + main, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + # Test with FAST_MATH enabled + kernel_fastmath = tilelang.compile( + main, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + + source_no_fastmath = kernel_no_fastmath.get_kernel_source() + source_fastmath = kernel_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} (two args) ===") + print("FAST_MATH=False:") + check_non_fastmath_usage(source_no_fastmath, mathop_name) + + print("FAST_MATH=True:") + check_non_fastmath_usage(source_fastmath, mathop_name) + + # Test numerical correctness + torch_dtype = dtype.as_torch() + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + b = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if mathop_name == "pow": + a = torch.abs(a) + 0.1 + b = torch.clamp(b, -3, 3) # Limit exponent range + elif mathop_name == "fmod": + b = torch.abs(b) + 0.1 # Avoid division by zero + + c_no_fastmath = kernel_no_fastmath(a, b) + c_fastmath = kernel_fastmath(a, b) + + # Both should produce similar results + torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3) + print(f"✓ {mathop_name} numerical test passed") + + +def run_abs_test(): + """Test that abs correctly maps to fabs (not __fabsf) in generated CUDA code""" + M, N = 128, 128 + block_M, block_N = 32, 32 + + @T.prim_func + def main( + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = T.abs(A[by * block_M + i, bx * block_N + j]) + + kernel = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + source = kernel.get_kernel_source() + print("\n=== Testing abs (maps to fabs) ===") + check_non_fastmath_usage(source, "fabs") + + # Test numerical correctness + a = torch.randn(M, N, device="cuda", dtype=torch.float32) + b = kernel(a) + expected = torch.abs(a) + + torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5) + print("✓ abs numerical test passed") + + +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
+ """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) + + # Test with FAST_MATH enabled + kernel_fastmath = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + + source_fastmath = kernel_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} (fastmath version) ===") + print("FAST_MATH=True:") + # Strip the __ prefix for checking in the CUDA source + cuda_mathop_name = mathop_name.lstrip("_") + check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) + + # Test numerical correctness + torch_dtype = dtype.as_torch() + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]: + a = torch.abs(a) + 0.1 + + b_fastmath = kernel_fastmath(a) + + # Compare with reference implementation + if cuda_mathop_name == "exp": + expected = torch.exp(a) + elif cuda_mathop_name == "log": + expected = torch.log(a) + else: + expected = b_fastmath # Just check compilation works + + torch.testing.assert_close(b_fastmath, expected, rtol=1e-3, atol=1e-3) + print(f"✓ {mathop_name} numerical test passed") + + +@pytest.mark.parametrize( + "name, func", + [ + ("exp", T.exp), + ("exp2", T.exp2), + ("exp10", T.exp10), + ("log", T.log), + ("log2", T.log2), + ("log10", T.log10), + ("sin", T.sin), + ("cos", T.cos), + ("tan", T.tan), + ("sinh", T.sinh), + ("cosh", T.cosh), + ("tanh", T.tanh), + ("atan", T.atan), + ("sqrt", T.sqrt), + ("rsqrt", T.rsqrt), + ("erf", T.erf), + ("floor", T.floor), + ("ceil", T.ceil), + ("trunc", T.trunc), + ("round", T.round), + ("nearbyint", T.nearbyint), + ], +) +@tilelang.testing.requires_cuda +def test_mathops_generate_no_fastmath(name, func): + """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" + run_single_arg_mathop_test(name, func, dtype=T.float32) + print(f"✓ {name} test passed") + + +@pytest.mark.parametrize( + "name, func", + [ + ("pow", T.pow), + ("fmod", T.fmod), + ], +) +@tilelang.testing.requires_cuda +def test_two_arg_mathops_fastmath(name, func): + """Test all two-argument mathops""" + run_two_arg_mathop_test(name, func, dtype=T.float32) + + +@tilelang.testing.requires_cuda +def test_abs_maps_to_fabs(): + """Test that abs correctly maps to fabs""" + run_abs_test() + + +@pytest.mark.parametrize( + "name, func", + [ + ("__exp", T.__exp), + ("__exp10", T.__exp10), + ("__log", T.__log), + ("__log2", T.__log2), + ("__log10", T.__log10), + ("__tan", T.__tan), + ("__cos", T.__cos), + ("__sin", T.__sin), + ], +) +@tilelang.testing.requires_cuda +def test_fastmath_versions(name, func): + """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code""" + run_fastmath_mathop_test(name, func, dtype=T.float32) + print(f"✓ {name} test passed") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/ir/test_ir_kernel_frame.py b/tilelang/original/testing/python/ir/test_ir_kernel_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a6bbc90ee334ebbf264106e13c1653a4b25caa --- /dev/null +++ b/tilelang/original/testing/python/ir/test_ir_kernel_frame.py @@ -0,0 
+1 @@ +# TODO: implement this test for tilelang/language/kernel.py diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1001.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1001.py new file mode 100644 index 0000000000000000000000000000000000000000..f2315ef21ee31f5cb5edea9ea1893f31c685bc2d --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1001.py @@ -0,0 +1,34 @@ +import torch +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def _cumsum_view_infer_layout(hidden): + num_tokens = T.dynamic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), T.float]): + with T.Kernel(num_tokens, threads=128) as pid: + smem = T.alloc_shared((hidden,), dtype=T.float32) + T.copy(x[pid, :], smem) + T.cumsum(T.view(smem, (1, hidden)), dim=1) + + return buggy_kernel + + +def test_cumsum_view_infer_layout(): + hidden = 128 + x = torch.randn(1, hidden, device="cuda", dtype=torch.float) + kernel = _cumsum_view_infer_layout(hidden) + kernel(x) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1008.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1008.py new file mode 100644 index 0000000000000000000000000000000000000000..a35a18449c5474a6122603418ff869cd6c8f235b --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1008.py @@ -0,0 +1,55 @@ +import torch +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def _fill_with_static_region_kernel(): + num_tokens = T.symbolic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 + with T.Kernel(num_tokens, threads=128) as _: + T.fill(x[0:128], 0) + + return buggy_kernel + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def _fill_with_dynamic_region_kernel(): + num_tokens = T.symbolic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 + with T.Kernel(num_tokens, threads=128) as _: + a, b = T.alloc_var(T.int), T.alloc_var(T.int) + T.fill(x[a:b], 0) + + return buggy_kernel + + +def test_fill_with_static_region_kernel(): + kernel = _fill_with_static_region_kernel() + x = torch.zeros((256,), dtype=torch.int64, device="cuda") + kernel(x) + + +def test_fill_with_dynamic_region_kernel(): + kernel = _fill_with_dynamic_region_kernel() + x = torch.zeros((256,), dtype=torch.int64, device="cuda") + kernel(x) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1115.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1115.py new file mode 100644 index 0000000000000000000000000000000000000000..658c126a049e9f8cdfdee44847780e4b9243b3ae --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1115.py @@ -0,0 +1,47 @@ +import torch +import tilelang +import tilelang.language as T + + +def test_int64_address(): + @tilelang.jit + def set_cache_kernel( + S, + D, + pos_ty="int64", + dtype=T.float32, + ): + 
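+ # pos carries one int64 destination row per source row; indexing cache[slot, i] with an int64
+ # slot is what exercised the int64 vs. int32 address-arithmetic mismatch noted below.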
@T.prim_func + def main( + pos: T.Tensor( + [ + S, + ], + pos_ty, + ), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32` + value: T.Tensor([S, D], dtype), # type: ignore + cache: T.Tensor([S, D], dtype), # type: ignore + ): + with T.Kernel(S, threads=128) as bx: + slot = pos[bx] + for i in T.Parallel(D): + cache[slot, i] = value[bx, i] + + return main + + D = 2 + S = 10 + cache = torch.rand((S, D), device="cuda", dtype=torch.float32) + value = torch.rand((S, D), device="cuda", dtype=torch.float32) + pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64) + pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32) + kernel_int64 = set_cache_kernel(S, D, "int64") + kernel_int32 = set_cache_kernel(S, D, T.int32) + kernel_int64(pos_int64, value, cache) + torch.testing.assert_close(cache, value) + kernel_int32(pos_int32, value, cache) + torch.testing.assert_close(cache, value) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1198.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1198.py new file mode 100644 index 0000000000000000000000000000000000000000..e6330e4356e6ebaaf4d62a8cd267de1acb660674 --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1198.py @@ -0,0 +1,19 @@ +import tilelang.testing +import tilelang.language as T + + +def test_issue_1198(): + @T.prim_func + def foo( + x: T.Buffer( + [ + 32, + ], + T.int32, + ), + ): + pass + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1210.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1210.py new file mode 100644 index 0000000000000000000000000000000000000000..2e141d7829fa61865d91690b5852faf68d3f3e4b --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1210.py @@ -0,0 +1,36 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +def _make_kernel(M, N): + dtype = T.bfloat16 + + @T.prim_func + def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)): + with T.Kernel(4, threads=1): + A = T.alloc_shared([N], dtype) + B = T.alloc_shared([N], dtype) + + # Regression for a bug where InjectSoftwarePipeline left the loop + # variable as a free var, causing MakePackedAPI to fail + for i in T.Pipelined(4, num_stages=1): + _id = ids[i] + T.copy(KV[_id, :], A) + T.clear(B) + + return fwd_main + + +def test_make_packed_api_no_free_loop_var(): + func = _make_kernel(4, 4) + # Keep warp-specialization/TMA disabled to match the original repro + cfg = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + } + tilelang.compile(func, pass_configs=cfg) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_1237.py b/tilelang/original/testing/python/issue/test_tilelang_issue_1237.py new file mode 100644 index 0000000000000000000000000000000000000000..bb936e4686595e7050a92ad6ffc40b1cd751dcbb --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_1237.py @@ -0,0 +1,23 @@ +import tilelang.testing +from tilelang import language as T + + +def test_issue_1237_dynamic_copy_extent_builds(): + # Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test. + # The goal is to ensure T.copy correctly handles dynamic extents + # (e.g., src slice length vs. 
static dst buffer size) during prim_func building. + + length = T.symbolic("len", dtype=T.int32) + + @T.prim_func + def sample_kernel(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821 + with T.Kernel(1, threads=32): + buffer_shared = T.alloc_shared((1024,), dtype=T.int32) + T.copy(global_tensor[0:length], buffer_shared) + + # Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute. + _ = sample_kernel + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_814.py b/tilelang/original/testing/python/issue/test_tilelang_issue_814.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f94bd744f0722232e0d7ebcc413ecfec2eb4b8 --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_814.py @@ -0,0 +1,50 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import torch + + +@tilelang.jit +def _tmp_var_kernel(N, block_N, dtype=T.float32): + @T.prim_func + def kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx: + for i in T.Parallel(block_N): + idx = bx * block_N + i + tmp = T.max(A[idx], 1) + B[idx] = tmp / 2 + A[idx] = tmp * 2 + + return kernel + + +def run_tmp_var_test(N=1024, block_N=128): + kernel = _tmp_var_kernel(N, block_N) + + a = torch.randn(N, device="cuda", dtype=torch.float) + b = torch.empty(N, device="cuda", dtype=torch.float) + + a_ref = a.clone() + + kernel(a, b) + + # Reference computation + tmp_ref = torch.maximum(a_ref, torch.tensor(1.0, dtype=torch.float, device="cuda")) + b_ref = tmp_ref / 2 + a_ref = tmp_ref * 2 + + # Validate correctness + tilelang.testing.torch_assert_close(a, a_ref, rtol=1e-2, atol=1e-2) + tilelang.testing.torch_assert_close(b, b_ref, rtol=1e-2, atol=1e-2) + + +def test_issue_814(): + """Test that temporary variables are correctly handled and not over-inlined""" + run_tmp_var_test(N=1024, block_N=128) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_830.py b/tilelang/original/testing/python/issue/test_tilelang_issue_830.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2a909d27e79bb77c6fe52e604ba64550dd5e8c --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_830.py @@ -0,0 +1,79 @@ +# ruff: noqa + +import torch +import tilelang +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit +def _empty_kernel(): + @T.prim_func + def empty_kernel(): + with T.Kernel(1, threads=32) as thread_idx: + pass + + return empty_kernel + + +@tilelang.testing.requires_cuda +def test_empty_kernel_lowering(): + # Ensure a valid CUDA runtime context is current on this thread for the + # target device before using driver API calls. Without this, calls like + # cuModuleLoadData can fail with CUDA_ERROR_INVALID_CONTEXT, especially + # for kernels that don't touch any device memory or streams beforehand + # (e.g., "empty" kernels) and therefore haven't triggered context + # creation implicitly. 
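+ # Any call that actually touches the device (e.g. torch.empty(1, device="cuda")) would also
+ # create the context; torch.cuda.set_device(0) is simply the most explicit way to do it here.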
+ torch.cuda.set_device(0) + kernel = _empty_kernel() + kernel() + + +@tilelang.jit +def _empty_with_dead_code_kernel(): + num_tokens = T.dynamic("num_tokens") + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float32]): + with T.Kernel(num_tokens, threads=32) as pid: + y = x[pid] + + return buggy_kernel + + +@tilelang.testing.requires_cuda +def test_empty_with_dead_code_kernel(): + kernel = _empty_with_dead_code_kernel() + x = torch.randn((128,), dtype=torch.float32, device="cuda") + kernel(x) + + +@tilelang.jit +def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False): + @T.prim_func + def kernel_with_tuple_kernel_binding(): + with T.Kernel(1, threads=32) as (pid,): + print(pid) + pass + + @T.prim_func + def kernel_with_scalar_kernel_binding(): + with T.Kernel(1, threads=32) as pid: + print(pid) + pass + + return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding + + +@tilelang.testing.requires_cuda +def test_empty_kernel_with_binding_variants(): + torch.cuda.set_device(0) + kernel = _empty_kernel_with_binding_variants() + kernel() + + tuple_kernel = _empty_kernel_with_binding_variants(use_tuple_binding=True) + tuple_kernel() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_96.py b/tilelang/original/testing/python/issue/test_tilelang_issue_96.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf5c69bd80f262915e4551faed7ff8991993ce1 --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_96.py @@ -0,0 +1,63 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import torch + + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + # changing num_stages to 0 gives correct results + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[by * block_M, ko * block_K], A_shared) + + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32): + func = matmul(N, N, N, block_M, block_N, block_K) + jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda") + + torch.manual_seed(0) + a = torch.randn(N, N, device="cuda", dtype=torch.float16) + b = torch.randn(N, N, device="cuda", dtype=torch.float16) + + ref_c = a @ b.T + c = jit_kernel(a, b) + + tilelang.testing.torch_assert_close(c, ref_c, rtol=1e-2, atol=0.2) + + +def test_pipeline_large_matrix(): + """Test pipeline stages with large matrix multiplication (4096x4096)""" + run_gemm_pipeline_test(4096) + + +def test_pipeline_small_matrix(): + """Test pipeline stages with smaller matrix multiplication""" + run_gemm_pipeline_test(1024) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/issue/test_tilelang_issue_merge_if.py 
b/tilelang/original/testing/python/issue/test_tilelang_issue_merge_if.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b1e3082922afe3bf5b3fe1775fe57f4dcc9fd8 --- /dev/null +++ b/tilelang/original/testing/python/issue/test_tilelang_issue_merge_if.py @@ -0,0 +1,35 @@ +import tilelang +from tilelang import tvm as tvm +from tvm.ir import IRModule +import tilelang.testing +import tilelang.language as T + + +def merge_if_test(): + @T.prim_func + def main(): + A = T.alloc_fragment((1,), T.float16) + B = T.alloc_fragment((1,), T.float16) + C = T.alloc_fragment((1,), T.float16) + D = T.alloc_fragment((1,), T.float16) + if A[0] == 0: + A[0] = 0 + if B[0] == 0: + B[0] = 0 + if C[0] == 0: + C[0] = 0 + if D[0] == 0: + D[0] = 0 + + return main + + +def test_merge_if(): + func = merge_if_test() + original_module = IRModule.from_expr(func) + transformed = tilelang.transform.MergeIfStmt()(original_module) + tvm.ir.assert_structural_equal(original_module["main"], transformed["main"], True) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_callback.py b/tilelang/original/testing/python/jit/test_tilelang_jit_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..98b88820cbab58a0bfe4fe237019f12d009c72f9 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_callback.py @@ -0,0 +1,234 @@ +from tilelang import language as T +import tilelang.testing +import tilelang +from tilelang.engine.callback import register_cuda_postproc_callback +import torch + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @register_cuda_postproc_callback + def tilelang_callback_cuda_postproc(code, _): + code = f"// {stramp}\n" + code + return code + + tilelang.disable_cache() + matmul_kernel = tilelang.compile(program, out_idx=-1) + tilelang.enable_cache() + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel 
source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1) + + A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda() + B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_cutedsl.py b/tilelang/original/testing/python/jit/test_tilelang_jit_cutedsl.py new file mode 100644 index 0000000000000000000000000000000000000000..7c613c4d1cceb04915b54b4a48d68bbb22d24f08 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_cutedsl.py @@ -0,0 +1,381 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: 
T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @tvm.register_global_func("tilelang_callback_cutedsl_postproc", override=True) + def tilelang_callback_cutedsl_postproc(code, _): + code = f"# {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl") + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def matmul_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def 
ref_program(A, B): + import torch + + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def run_cutedsl_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + + profiler = matmul_kernel.get_profiler() + + cutedsl_latency = profiler.do_bench(func=matmul_kernel) + print(f"CuTeDSL Latency: {cutedsl_latency} ms") + + assert cutedsl_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_cutedsl_kernel_do_bench(): + run_cutedsl_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cutedsl_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_cutedsl_kernel_multi_stream(): + run_cutedsl_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cutedsl_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_cutedsl_dynamic_shape(): + run_cutedsl_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 
2) + + run_cutedsl_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cutedsl_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_gemm.py b/tilelang/original/testing/python/jit/test_tilelang_jit_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..97391f26f37ad6bf064ed453f221668ebc32f548 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_gemm.py @@ -0,0 +1,124 @@ +from tilelang import language as T +import tilelang.testing +import tilelang +import torch + + +@tilelang.jit( + out_idx=-1, # create the output tensor during runtime +) +def matmul_kernel_jit( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_kernel_jit( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=0, + num_threads=128, +): + matmul_kernel = matmul_kernel_jit( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda() + B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_f16f16f16_nn_kernel_jit(): + run_gemm_kernel_jit( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_gemm_cython.py b/tilelang/original/testing/python/jit/test_tilelang_jit_gemm_cython.py new 
file mode 100644 index 0000000000000000000000000000000000000000..c5399fc51f1a12aafa5c3dae3e5e5f8e1dff6508 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -0,0 +1,573 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) + def tilelang_callback_cuda_postproc(code, _): + code = f"// {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython") + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) 
+ if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def run_cython_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + cython_matmul_kernel = tilelang.compile(program, execution_backend="cython") + ctypes_matmul_kernel = tilelang.compile(program, execution_backend="ctypes") + + cython_profiler = cython_matmul_kernel.get_profiler() + ctypes_profiler = ctypes_matmul_kernel.get_profiler() + + cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel) + print(f"cython Latency: {cython_latency} ms") + + # assert ctypes_latency is not None + + tvm_latency = cython_profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel) + print(f"ctypes Latency: {ctypes_latency} ms") + + assert cython_latency is not None + + +def test_cython_kernel_do_bench(): + run_cython_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_cython_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="cython") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_cython_kernel_multi_stream(): + run_cython_kernel_multi_stream(512, 1024, 768, False, False, T.float16, 
T.float16, T.float16, 128, 256, 32, 2) + + +def run_cython_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="cython") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_cython_dynamic_shape(): + run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_cython_dynamic_shape_with_out_idx( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=-1) + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + + tensor_c = matmul_kernel(tensor_a, tensor_b) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_cython_dynamic_shape_with_out_idx(): + run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def matmul_int_variable( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.int32, + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + 
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = C_local[i, j] + offset + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_int_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) + matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + tensor_c = matmul_kernel(tensor_a, tensor_b, 1) + + tensor_ref_c = torch.matmul(tensor_a, tensor_b).to(out_dtype) + 1 + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, rtol=1e-2, atol=1e-2) + + +def test_matmul_int_variable(): + run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128) + + +def matmul_float_variable( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.float32, + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = C_local[i, j] + offset + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_float_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) + matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + tensor_c = matmul_kernel(tensor_a, tensor_b, 1.0) + + tensor_ref_c = torch.matmul(tensor_a, tensor_b).to(out_dtype) + 
1.0 + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, rtol=1e-2, atol=1e-2) + + +def test_matmul_float_variable(): + run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_nullptr.py b/tilelang/original/testing/python/jit/test_tilelang_jit_nullptr.py new file mode 100644 index 0000000000000000000000000000000000000000..a9edb5e930ac5687aed24dd44cc4a5e8685930d9 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_nullptr.py @@ -0,0 +1,54 @@ +import torch +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl +import tilelang.language as T +from tilelang.utils import map_torch_type + + +@tl.jit +def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, with_bias=False): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + Bias: T.Tensor((N), accum_dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] += Bias[bx * block_N + j] + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) + b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) + c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) + kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) + kernel(a, b, c, None) + + +def test_nullptr(): + run_test(1024, 1024, 1024, 128, 128, 32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_nvrtc.py b/tilelang/original/testing/python/jit/test_tilelang_jit_nvrtc.py new file mode 100644 index 0000000000000000000000000000000000000000..b6823b8cca054045b3e751347c73e89a1555af04 --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -0,0 +1,506 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = 
T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) + def tilelang_callback_cuda_postproc(code, _): + code = f"// {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C 
= ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def run_nvrtc_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + + profiler = matmul_kernel.get_profiler() + + nvrtc_latency = profiler.do_bench(func=matmul_kernel) + print(f"NVRTC Latency: {nvrtc_latency} ms") + + assert nvrtc_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_nvrtc_kernel_do_bench(): + run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_nvrtc_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_nvrtc_kernel_multi_stream(): + run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_nvrtc_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_nvrtc_dynamic_shape(): + run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + 
+ run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): + """Test im2col TMA descriptor functionality in NVRTC backend.""" + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) + + conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + out_c = conv_kernel(a, b) + + # Reference implementation using torch.conv2d + def ref_program(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=S, padding=P, dilation=D) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + ref_c = ref_program(a, b) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_nvrtc_im2col_tma_desc(): + """Test im2col TMA descriptor with NVRTC backend.""" + if not check_hopper(): + import pytest + + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") + + # Small test case for im2col TMA descriptor + run_nvrtc_im2col_tma_desc( + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) + + +def test_nvrtc_l2_persistent_map(): + """Test L2 persistent cache annotation with elementwise add.""" + from tilelang.language import annotate_l2_hit_ratio + + M = 1024 + N = 1024 + + @tilelang.jit(out_idx=[-1], execution_backend="nvrtc") + def elementwise_add_with_l2_cache( + M, + N, + block_size=256, + dtype=T.float32, + ): + @T.prim_func + def kernel( + 
A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(M * N // block_size, threads=block_size) as bx: + # Annotate L2 persistent cache for buffer B + # B will be accessed multiple times and benefit from L2 caching + annotate_l2_hit_ratio({B: 0.8}) + + for i in T.serial(block_size): + idx = bx * block_size + i + if idx < M * N: + row = idx // N + col = idx % N + C[row, col] = A[row, col] + B[row, col] + + return kernel + + # Compile the kernel + kernel = elementwise_add_with_l2_cache(M, N) + + # Create test tensors + a = torch.randn(M, N, dtype=torch.float32).cuda() + b = torch.randn(M, N, dtype=torch.float32).cuda() + + # Run kernel with out_idx=[-1], C is returned not passed in + c = kernel(a, b) + + # Verify correctness + ref_c = a + b + tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5) + + print("L2 persistent map test passed!") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_parcompile.py b/tilelang/original/testing/python/jit/test_tilelang_jit_parcompile.py new file mode 100644 index 0000000000000000000000000000000000000000..56201e1cc5aabf0b6e3e56e691ba4d624ada901e --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_parcompile.py @@ -0,0 +1,75 @@ +import tilelang.testing +import tilelang +import torch +from tilelang import language as T + + +@tilelang.jit( + out_idx=-1, # create the output tensor during runtime + verbose=True, +) +def matmul_kernel_jit( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A=False, + trans_B=True, + in_dtype=T.float16, + out_dtype=T.float32, + accum_dtype=T.float32, + num_stages=2, + threads=128, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def test_par_compile(): + configs = [ + (1024, 1024, 1024, 128, 128, 32), + (2048, 2048, 2048, 256, 256, 64), + (4096, 4096, 4096, 64, 64, 128), + ] + kernels = matmul_kernel_jit.par_compile(configs) + for (M, N, K, _, _, _), kernel in zip(configs, kernels): + A = torch.randn(M, K, dtype=torch.float16).cuda() + B = torch.randn(N, K, dtype=torch.float16).cuda() + ref = (A @ B.T).float() + C = kernel(A, B) + tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/tilelang/original/testing/python/jit/test_tilelang_jit_tvm_ffi.py new file mode 
100644 index 0000000000000000000000000000000000000000..a0df2719213676b0a27e018a80439da0af7fa53f --- /dev/null +++ b/tilelang/original/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -0,0 +1,454 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + 
+ C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def run_tvm_ffi_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + + profiler = matmul_kernel.get_profiler() + + tvm_ffi_latency = profiler.do_bench(func=matmul_kernel) + print(f"tvm_ffi Latency: {tvm_ffi_latency} ms") + + assert tvm_ffi_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_tvm_ffi_kernel_do_bench(): + run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_tvm_ffi_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_tvm_ffi_kernel_multi_stream(): + run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + +def run_tvm_ffi_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_tvm_ffi_dynamic_shape(): + run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) 
+ + run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) + + run_tvm_ffi_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2 + ) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def run_tvm_ffi_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): + """Test im2col TMA descriptor functionality in tvm_ffi backend.""" + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) + + conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") + + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + out_c = conv_kernel(a, b) + + # Reference implementation using torch.conv2d + def ref_program(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=S, padding=P, dilation=D) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + ref_c = ref_program(a, b) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_tvm_ffi_im2col_tma_desc(): + """Test im2col TMA descriptor with tvm_ffi backend.""" + if not check_hopper(): + import pytest + + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") + + # Small test case for im2col TMA descriptor + run_tvm_ffi_im2col_tma_desc( + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) + + +def test_tvm_ffi_l2_persistent_map(): + """Test L2 persistent cache annotation with elementwise add.""" + from tilelang.language import annotate_l2_hit_ratio + + M = 1024 + N = 1024 + + 
@tilelang.jit(out_idx=[-1], execution_backend="tvm_ffi") + def elementwise_add_with_l2_cache( + M, + N, + block_size=256, + dtype=T.float32, + ): + @T.prim_func + def kernel( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(M * N // block_size, threads=block_size) as bx: + # Annotate L2 persistent cache for buffer B + # B will be accessed multiple times and benefit from L2 caching + annotate_l2_hit_ratio({B: 0.8}) + + for i in T.serial(block_size): + idx = bx * block_size + i + if idx < M * N: + row = idx // N + col = idx % N + C[row, col] = A[row, col] + B[row, col] + + return kernel + + # Compile the kernel + kernel = elementwise_add_with_l2_cache(M, N) + + source = kernel.get_host_source() + assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" + ) + assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + ) + + # Create test tensors + a = torch.randn(M, N, dtype=torch.float32).cuda() + b = torch.randn(M, N, dtype=torch.float32).cuda() + + # Run kernel with out_idx=[-1], C is returned not passed in + c = kernel(a, b) + + # Verify correctness + ref_c = a + b + tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5) + + print("L2 persistent map test passed!") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..97d050b73012f30a11feb442ab8147493a494604 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -0,0 +1,229 @@ +import torch +import torch.backends +from tilelang import tvm as tvm +import tilelang.testing +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.bfloat16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + 
stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(matmul, out_idx=[2]) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, 
torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + + C = torch.zeros(M, N, device="cuda", dtype=accum_dtype) + + C = kernel(A, B) + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype) + tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 0) +def test_assert_tl_matmul_bfloat16(): + assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) + + +if __name__ == "__main__": + # tilelang.testing.main() + test_assert_tl_matmul_bfloat16() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_element_wise_add.py new file mode 100644 index 0000000000000000000000000000000000000000..501b38fda8c6e11b59a7df424b00434454e25e39 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_element_wise_add.py @@ -0,0 +1,109 @@ +import tilelang.testing +from tilelang import language as T +import torch + + +def elementwise_add( + M, + N, + block_M, + block_N, + in_dtype, + out_dtype, + threads, +): + @T.prim_func + def main( + A: T.Tensor((M, N), in_dtype), + B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + + for local_y, local_x in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + + C[y, x] = A[y, x] + B[y, x] + + return main + + +def run_elementwise_add( + M, + N, + in_dtype, + out_dtype, + block_M, + block_N, + num_threads=128, +): + program = elementwise_add( + M, + N, + block_M, + block_N, + in_dtype, + out_dtype, + num_threads, + ) + + kernel = tilelang.compile(program, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + C = torch.add(A, B) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_elementwise_add_f32(): + run_elementwise_add( + 512, + 1024, + T.float32, + T.float32, + 128, + 256, + ) + + +def test_elementwise_add_f16(): + run_elementwise_add( + 512, + 1024, + T.float16, + T.float16, + 128, + 256, + ) + + +def test_elementwise_add_i32(): + run_elementwise_add( + 512, + 1024, + T.int32, + T.int32, + 128, + 256, + ) + + +def test_elementwise_add_f32f16(): + run_elementwise_add( + 512, + 1024, + T.float32, + T.float16, + 128, + 256, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..276083b26252070bb614fc1a2d8c09a748b4a60a --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py @@ -0,0 +1,62 @@ +import torch +import tilelang.testing +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + 
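+# calc_diff above returns 1 - 2*<x, y> / (|x|^2 + |y|^2): a cosine-similarity-style
+# relative error that is 0 when x and y are identical and approaches 1 as they
+# diverge. The fp8 GEMM check below treats diff < 1e-3 as a pass.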
+ +def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype): + @T.prim_func + def main( + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by): + A_shared = T.alloc_shared((bM, bK), in_dtype) + B_shared = T.alloc_shared((bN, bK), in_dtype) + C_local = T.alloc_fragment((bM, bN), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, bK), num_stages=3): + T.copy(A[by * bM, k * bK], A_shared) + T.copy(B[bx * bN, k * bK], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * bM, bx * bN]) + + return main + + +def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype): + func = matmul_nt(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(func, out_idx=-1) + + A = torch.randn(M, K).to(map_torch_type(in_dtype)).cuda() + B = torch.randn(N, K).to(map_torch_type(in_dtype)).cuda() + + C = kernel(A, B) + + ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) + print(C) + print(ref_c) + diff = calc_diff(C, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9) +def test_assert_matmul(): + assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e4m3fn, T.float32, T.float32) + assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e5m2, T.float32, T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba369b6b966f7a32e4174f63b510c1d04b0bc49 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -0,0 +1,229 @@ +import torch +import torch.backends +from tilelang import tvm as tvm +import tilelang.testing +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 
2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(matmul, out_idx=[2]) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + print(src_code) + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), 
dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + + C = kernel(A, B) + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 9) +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py new file mode 100644 index 0000000000000000000000000000000000000000..7b757992a714a58f54f2b0010818b1b258f42d70 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py @@ -0,0 +1,172 @@ +import torch +import torch.backends +import tilelang.testing +from tilelang import tvm as tvm +from tvm import DataType +import tilelang.language as T +from tilelang import JITKernel +from tilelang.transform.simplify import apply_simplify +from tilelang.utils.tensor import map_torch_type +from typing import Optional + +tilelang.testing.set_random_seed(0) + + +def gemv_simt( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + trans_A: bool, + trans_B: bool, + with_bias: bool = False, + n_partition: Optional[int] = 4, + reduce_thread: Optional[int] = 32, +): + assert n_partition is not None, "n_partition must be provided" + assert reduce_thread is not None, ( + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) + + assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + + block_K = reduce_thread * micro_size_k + + A_shape = (M, K) + B_shape = (N, K) + Bias_shape = (N,) + C_shape = (M, N) + + dp4a_size = 4 + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor(C_shape, out_dtype), + ): + with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_local = T.alloc_local((micro_size_k,), in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.get_thread_binding(0) + ni = T.get_thread_binding(1) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * 
micro_size_k + v] + + for v in T.vectorized(micro_size_k): + B_local[v] = B[ + bx * n_partition + ni, + ko * block_K + kr * micro_size_k + v, + ] + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + if with_bias: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + else: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return apply_simplify(main) + + +def evaluate_gemv_simt( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + trans_A: bool = False, + trans_B: bool = True, + with_bias: bool = False, +): + program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias) + + kernel = JITKernel(program, target="cuda") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + Bias = torch.randn(N).to(accum_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5 + + C = torch.zeros(M, N).to(out_dtype).cuda() + + if with_bias: + kernel(A, B, Bias, C) + else: + kernel(A, B, C) + + ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32)) + if with_bias: + ref_c += Bias.to(torch.float32) + + print(C) + print(ref_c) + tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 9) +def test_gemv_simt(): + evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc95e98ada55e145a5e72317630ef0587452f76 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -0,0 +1,539 @@ +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + 
@T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=0, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile(program, out_idx=[2]) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +def test_gemm_f16f16f32_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float16, + T.float16, + T.float32, + 128, + 128, + 32, + ) + + +def test_gemm_bf16bf16f32_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.bfloat16, + T.bfloat16, + T.float32, + 128, + 128, + 32, + ) + + +def test_gemm_f32f32f32_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + T.float32, + T.float32, + T.float32, + 64, + 128, + 32, + ) + + +def test_gemm_f16f16f16_tn(): + run_gemm( + 512, + 1024, + 768, + True, + False, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +def test_gemm_f16f16f16_nt(): + run_gemm( + 512, + 1024, + 768, + False, + True, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +def test_gemm_i8i8i32_nt(): + run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64) + + +def test_gemm_i8i8i32_tn(): + run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64) + + +def test_gemm_f64f64f64_nt(): + run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16) + + +def test_gemm_f32f32f32_nt(): + run_gemm( + 512, + 1024, + 768, + False, + True, + T.float32, + T.float32, + T.float32, + 64, + 128, + 32, + ) + + +def test_gemm_f32f32f32_tn(): + run_gemm( + 512, + 1024, + 768, + True, + False, + T.float32, + T.float32, + T.float32, + 64, + 128, + 32, + ) + + +def test_pad_aligned_f16f16f16_nn(): + run_gemm( + 512 - 8, + 1024 - 32, + 768 - 24, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + 
+def test_pad_f16f16f16_nn(): + run_gemm( + 512 - 9, + 1024 - 7, + 768 - 5, + False, + False, + T.float16, + T.float16, + T.float16, + 128, + 256, + 32, + 2, + ) + + +def test_pad_f16f16f32_nn(): + run_gemm( + 512 + 19, + 1024 + 17, + 768 + 15, + False, + False, + T.float16, + T.float16, + T.float32, + 128, + 64, + 32, + ) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + B_local = T.alloc_fragment(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + T.copy(B_shared, B_local) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_local) + T.gemm(A_shared, B_local, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=1, + num_threads=128, +): + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile(program, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float) + B = B.to(torch.float) + C = torch.matmul(A, B) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +# WGMMA only supports B in shared +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_gemm_f16f16f16_sr(): + run_gemm_sr( + 512, + 1024, + 768, + False, + True, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") + A_local = T.alloc_fragment(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared") + C_local = T.alloc_fragment((block_M, 
block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + T.copy(A_shared, A_local) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(A_shared, A_local) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_local, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=1, + num_threads=128, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile(program, out_idx=[2]) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +# Register source A operand GMMAs must have K-major A layout. +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_gemm_f16f16f16_rs(): + run_gemm_rs( + 512, + 1024, + 768, + True, + False, + T.float16, + T.float16, + T.float16, + 128, + 128, + 32, + 0, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1b75ebc590ed9189b71f46e8db40734de879db --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -0,0 +1,241 @@ +import torch +import torch.backends +from tilelang import tvm as tvm +import tilelang.testing +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.bfloat16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 
32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(matmul, out_idx=[2]) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, 
K), dtype=torch.int8).to(in_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + + C = kernel(A, B) + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype) + tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 0) +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 0) +def test_assert_tl_matmul_bfloat16(): + assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 9) +def test_assert_tl_matmul_fp8(): + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_simt.py new file mode 100644 index 0000000000000000000000000000000000000000..584aa854a715d61ab306a7be153fdd35973e1cd9 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -0,0 +1,170 @@ +import torch +import torch.backends +import tilelang.testing +from tilelang import tvm as tvm +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.transform import simplify_prim_func + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul_simt( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + # This is a debug config + block_size_x = 8 + block_size_y = 8 + thread_row_tiles = 16 + thread_col_tiles = 16 + chunk = 16 + + shared_scope = "shared" + + block_M = block_size_x * thread_row_tiles + block_N = block_size_y * thread_col_tiles + block_K = chunk + + # Pipeline Stage + + A_shape = (M, K) + B_shape = (N, K) + C_shape = (M, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + + threads = thread_row_tiles * thread_col_tiles + local_size_a = block_M // thread_row_tiles + local_size_b = block_N // thread_col_tiles + local_size_c = (block_M // thread_row_tiles) * (block_N // 
thread_col_tiles) + + micro_size_k = 128 // DataType(in_dtype).bits + dp4a_size = 4 + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor(C_shape, out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + + A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype) + B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype) + C_local = T.alloc_local((local_size_c,), accum_dtype) + + tid = T.get_thread_binding() + + warp_m = tid % thread_row_tiles + warp_n = tid // thread_row_tiles + + T.clear(C_local) + + for ko in T.serial(K // block_K): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial((block_K // micro_size_k)): + for i in T.serial(local_size_a): + for mk in T.vectorized(micro_size_k): + A_local[i, mk] = A_shared[warp_m * local_size_a + i, ki * micro_size_k + mk] + + for i in T.serial(local_size_b): + for mk in T.vectorized(micro_size_k): + B_local[i, mk] = B_shared[warp_n * local_size_b + i, ki * micro_size_k + mk] + + for i, j in T.grid(local_size_a, local_size_b): + for mk in T.serial(micro_size_k // dp4a_size): + if use_dp4a: + T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], C_local[i * local_size_b + j]) + else: + for dp4a_idx in T.serial(dp4a_size): + C_local[i * local_size_b + j] += ( + A_local[i, mk * dp4a_size + dp4a_idx] * B_local[j, mk * dp4a_size + dp4a_idx] + ) + + for i, j in T.grid(local_size_a, local_size_b): + C[by * block_M + warp_m * local_size_a + i, bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(matmul, out_idx=[2]) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + print(src_code) + # src_code is the generated cuda source + assert src_code is not None + + if in_dtype == T.int8: + A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = kernel(A, B) + + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py 
b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py new file mode 100644 index 0000000000000000000000000000000000000000..1f76600325369d7e21a0b9c275042ea9f7ea60a2 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py @@ -0,0 +1,86 @@ +import tilelang.testing +import tilelang +import tilelang.language as T +import torch + + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K * 2), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N * 2), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Clear local accumulation + T.clear(C_local) + T.clear(B_shared) + T.clear(A_shared) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # T.copy(A[by * block_M, ko * block_K], A_shared) + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k + block_K] = A[by * block_M + i, ko * block_K + k] + + # Copy tile of B + # T.copy(B[ko * block_K, bx * block_N], B_shared) + for i, k in T.Parallel(block_K, block_N): + B_shared[i, k] = B[ko * block_K + i, bx * block_N + k] + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared[:, block_K:], B_shared[0:block_K, 0:block_N], C_local) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, block_K: int): + # 1. Define the kernel (matmul) and compile/lower it into an executable module + func = matmul(M, N, K, block_M, block_N, block_K) + + # 2. Compile the kernel into a torch function + # out_idx specifies the index of the output buffer in the argument list + # if out_idx is specified, the tensor will be created during runtime + # target currently can be "cuda" or "hip" or "cpu". 
+    jit_kernel = tilelang.compile(
+        func,
+        out_idx=[2],
+        target="cuda",
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        },
+    )
+    # Create random input tensors on the GPU
+    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    # Run the kernel through the Profiler
+    c = jit_kernel(a, b)
+
+    print(c)
+    # Reference multiplication using PyTorch
+    ref_c = a @ b
+
+    # Validate correctness
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("Kernel output matches PyTorch reference.")
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(7, 5)
+def test_tilelang_kernel_gemm_with_stride():
+    run_gemm_with_stride_ss(128, 128, 64, 32, 32, 32)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemv_simt.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemv_simt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4a5c8249239c6002028f101b037ff6295cb74f8
--- /dev/null
+++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_gemv_simt.py
@@ -0,0 +1,179 @@
+import torch
+import torch.backends
+import tilelang.testing
+from tilelang import tvm as tvm
+from tvm import DataType
+import tilelang.language as T
+from tilelang import JITKernel
+from tilelang.transform.simplify import apply_simplify
+from tilelang.utils.tensor import map_torch_type
+from typing import Optional
+
+tilelang.testing.set_random_seed(0)
+
+
+def gemv_simt(
+    M: int,
+    N: int,
+    K: int,
+    in_dtype: str,
+    out_dtype: str,
+    accum_dtype: str,
+    trans_A: bool,
+    trans_B: bool,
+    with_bias: bool = False,
+    n_partition: Optional[int] = 4,
+    reduce_thread: Optional[int] = 32,
+):
+    assert n_partition is not None, "n_partition must be provided"
+    assert reduce_thread is not None, (
+        "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
+    )
+
+    assert isinstance(N, int) and isinstance(K, int), "Dynamic N and K are not supported currently"
+
+    assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
+    assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
+
+    MAX_TRANSACTION_SIZE_IN_BITS = 128
+    micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
+
+    block_K = reduce_thread * micro_size_k
+
+    A_shape = (M, K)
+    B_shape = (N, K)
+    Bias_shape = (N,)
+    C_shape = (M, N)
+
+    dp4a_size = 4
+    use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
+
+    @T.prim_func
+    def main(
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        Bias: T.Tensor(Bias_shape, out_dtype),
+        C: T.Tensor(C_shape, out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
+            bx,
+            by,
+        ):
+            A_local = T.alloc_local((micro_size_k,), in_dtype)
+            B_local = T.alloc_local((micro_size_k,), in_dtype)
+            accum_res = T.alloc_local((1,), accum_dtype)
+            reduced_accum_res = T.alloc_local((1,), accum_dtype)
+
+            kr = T.get_thread_binding(0)
+            ni = T.get_thread_binding(1)
+
+            T.clear(accum_res)
+            for ko in T.serial(T.ceildiv(K, block_K)):
+                for v in T.vectorized(micro_size_k):
+                    A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
+
+                for v in T.vectorized(micro_size_k):
+                    B_local[v] = B[
+                        bx * n_partition + ni,
+                        ko * block_K + kr * micro_size_k + v,
+                    ]
+
+                if
use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + if with_bias: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + else: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return apply_simplify(main) + + +def evaluate_gemv_simt( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + trans_A: bool = False, + trans_B: bool = True, + with_bias: bool = False, +): + program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias) + + kernel = JITKernel(program, target="cuda") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + Bias = torch.randn(N).to(accum_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5 + + C = torch.zeros(M, N).to(out_dtype).cuda() + + if with_bias: + kernel(A, B, Bias, C) + else: + kernel(A, B, C) + + ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32)) + if with_bias: + ref_c += Bias.to(torch.float32) + ref_c = ref_c.to(out_dtype) + print(C) + print(ref_c) + tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 0) +def test_gemv_simt(): + evaluate_gemv_simt(1, 1024, 1024, T.float16, T.float16, T.float16, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.int8, T.int32, T.int32, with_bias=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(8, 9) +def test_gemv_simt_fp8(): + evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/tilelang/original/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..9d60e5229ad455d2c57c98774e9d82279fa5b691 --- /dev/null +++ b/tilelang/original/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -0,0 +1,408 @@ +import torch +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +from tilelang.intrinsics import ( + make_mma_swizzle_layout as make_swizzle_layout, +) + +from tilelang.intrinsics.mma_macro_generator import ( + INT4TensorCoreIntrinEmitter, + 
INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from tilelang.transform import simplify_prim_func + +tilelang.testing.set_random_seed(42) + + +# @simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, 
block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile( + matmul, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True, + }, + ) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + C = kernel(compressed_A, compressed_B) + print(C) + latency = profiler.do_bench() + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul_correctness(): + assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(128, 128, 64, T.int8, T.int32, T.int32) + + +@simplify_prim_func +def tl_matmul_weight_only_transform( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + K = K // 2 + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: 
T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + import bitblas + + matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) + kernel = tilelang.compile(matmul, out_idx=[2]) + profiler = kernel.get_profiler() + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + transform_b = 3 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype=T.int8, + storage_dtype=T.int8, + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + LB = ladder_permutate(compressed_B.cpu()).cuda() + C = kernel(compressed_A, LB) + + latency = profiler.do_bench() + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_package("bitblas") +@tilelang.testing.requires_llvm +def test_assert_tl_matmul_weight_only_transform(): + assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, T.int8, T.int32, T.int32) + + +if __name__ == "__main__": + # tilelang.testing.main() + assert_tl_matmul_correctness(128, 128, 
128, T.int8, T.int32, T.int32) diff --git a/tilelang/original/testing/python/language/test_tilelang_capture.py b/tilelang/original/testing/python/language/test_tilelang_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..47fec999a2531f0f8d82dd5d9145175827e47e6f --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_capture.py @@ -0,0 +1,41 @@ +import tilelang.language as T +import tilelang.testing +import torch +import weakref +import gc + + +def test_tilelang_capture(): + @tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + def get_dummy_kernel(): + @T.prim_func + def dummy_kernel( + a: T.Tensor[(1,), T.float32], + ): + with T.Kernel(1) as _: + a[0] = 1 + + return dummy_kernel + + a = torch.randn(1, 1024) + a_weak = weakref.ref(a) + _kernel = get_dummy_kernel() + del a + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + a_upgrade = a_weak() + assert a_upgrade is None, "A is not garbage collected" + + # use objgraph to debug + # if a_upgrade is not None: + # objgraph.show_backrefs([a_upgrade], max_depth=5) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_intimm.py b/tilelang/original/testing/python/language/test_tilelang_intimm.py new file mode 100644 index 0000000000000000000000000000000000000000..46c2c79873ebd079aac21cf6f50ea1738f0105f5 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_intimm.py @@ -0,0 +1,28 @@ +import tilelang +import tilelang.testing +import tilelang.language as T + + +def test_tilelang_intimm(): + T.int32(0x7FFFFFFF) + T.int32(-0x7FFFFFFF - 1) + T.uint32(0xFFFFFFFF) + T.int64(0x7FFFFFFFFFFFFFFF) + T.int64(-0x7FFFFFFFFFFFFFFF - 1) + T.uint64(0xFFFFFFFFFFFFFFFF) + + a = T.int32() + a & 0x7FFFFFFF + + a = T.uint32() + a & 0xFFFFFFFF + + a = T.int64() + a & 0x7FFFFFFFFFFFFFFF + + a = T.uint64() + a & T.uint64(0xFFFFFFFFFFFFFFFF) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_alias.py b/tilelang/original/testing/python/language/test_tilelang_language_alias.py new file mode 100644 index 0000000000000000000000000000000000000000..48fe1ac4d8d377394e2321bb28ddb83f7eca9824 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_alias.py @@ -0,0 +1,57 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + X_shared = A_shared[:block_M, :block_K] + X_local = C_local[:block_M, :block_K] + T.clear(X_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # This is a sugar syntax for parallelized copy + aliased_offset = T.int32() + T.let(aliased_offset, ko * block_K) + T.copy(A[by * block_M, aliased_offset], X_shared) + + # Demonstrate parallelized copy from 
global to shared for B + T.copy(B[bx * block_N, ko * block_K], B_shared[:block_N, :block_K]) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(X_shared, B_shared, C_local, transpose_B=True) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) + kernel = tilelang.compile(program, out_idx=[2], target="cuda") + kernel.run_once() + + +def test_matmul(): + run_matmul(1024, 1024, 1024, 128, 128, 32) + + +if __name__ == "__main__": + test_matmul() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_all_of.py b/tilelang/original/testing/python/language/test_tilelang_language_all_of.py new file mode 100644 index 0000000000000000000000000000000000000000..db694d337620fb8d2e76e60d4931d7dfae39593e --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_all_of.py @@ -0,0 +1,315 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import torch + + +def ref_program(A, B, BlockMask, block_M, block_N, block_K): + M, K = A.shape + N = B.shape[1] + ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + for i in range(M // block_M): + for j in range(N // block_N): + accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) + for k in range(K // block_K): + if torch.all(BlockMask[i, j, k]): + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) + return ref_c + + +def blocksparse_matmul_global( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + dtype=T.float16, + accum_dtype=T.float32, +): + block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if T.all_of(BlockMask[by, bx, k, :]): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def blocksparse_matmul_shared( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + dtype=T.float16, + accum_dtype=T.float32, +): + block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with 
T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + block_mask_shared = T.alloc_shared(condition_dim, "bool") + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + for i in T.serial(condition_dim): + block_mask_shared[i] = BlockMask[by, bx, k, i] + # or T.all_of(block_mask_local[0:condition_dim]) + # or T.all_of(block_mask_local[:]) + if T.all_of(block_mask_shared): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def blocksparse_matmul_local( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + dtype=T.float16, + accum_dtype=T.float32, +): + block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + block_mask_local = T.alloc_local(condition_dim, "bool") + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + for i in T.serial(condition_dim): + block_mask_local[i] = BlockMask[by, bx, k, i] + # or T.all_of(block_mask_local[0:condition_dim]) + # or T.all_of(block_mask_local[:]) + if T.all_of(block_mask_local): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_global( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile(func, out_idx=-1) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + 
+def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_shared( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile( + func, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_local( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile( + func, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def test_block_sparse_matmul_global(): + run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +def test_block_sparse_matmul_shared(): + run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +def test_block_sparse_matmul_local(): + run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_alloc.py b/tilelang/original/testing/python/language/test_tilelang_language_alloc.py new file mode 100644 index 0000000000000000000000000000000000000000..883f65c3cb0487ff7a7533684c660089f960fc56 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_alloc.py @@ -0,0 +1,162 @@ +import tilelang.testing +from tilelang import language as T + + +def 
alloc_var( + N, + block_N, + dtype, +): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + A_shared = T.alloc_shared([block_N], dtype) + tmp = T.alloc_var(dtype) + tmp = 1 # noqa: F841 + T.copy(A[bx * block_N], A_shared) + T.copy(A_shared, B[bx * block_N]) + + return main + + +def run_alloc_var( + N, + block_N, + dtype, + min=None, + max=None, +): + program = alloc_var(N, block_N, dtype) + + kernel = tilelang.compile(program, out_idx=[1]) + code = kernel.get_kernel_source() + assert "tmp =" in code + + +def test_alloc_var(): + run_alloc_var(1024, 128, T.float16) + + +def alloc_var_add( + N, + block_N, + dtype, +): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + A_shared = T.alloc_shared([block_N], dtype) + tmp = T.alloc_var(dtype) + tmp = 1 # noqa: F841 + T.copy(A[bx * block_N], A_shared) + for i in T.Parallel(block_N): + A_shared[i] = A_shared[i] + tmp + T.copy(A_shared, B[bx * block_N]) + + return main + + +def run_alloc_var_add( + N, + block_N, + dtype, +): + program = alloc_var_add(N, block_N, dtype) + + kernel = tilelang.compile(program, out_idx=[1]) + code = kernel.get_kernel_source() + assert "tmp =" in code + + +def test_alloc_var_add(): + run_alloc_var_add(1024, 128, T.float16) + + +def alloc_var_with_initializer( + N, + block_N, + dtype, + init_value, +): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + tmp = T.alloc_var(dtype, init_value) + T.copy(A[bx * block_N], B[bx * block_N]) + for i in T.Parallel(block_N): + B[bx * block_N + i] = tmp + + return main + + +def run_alloc_var_with_initializer( + N, + block_N, + dtype, + init_value, +): + program = alloc_var_with_initializer(N, block_N, dtype, init_value) + + kernel = tilelang.compile(program, out_idx=[1]) + code = kernel.get_kernel_source() + assert f"= {init_value};" in code + + +def test_alloc_var_with_initializer(): + run_alloc_var_with_initializer(256, 64, T.int32, 5) + + +def alloc_multi_vars_with_initializer( + N, + block_N, + dtype, +): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + tmp0 = T.alloc_var(dtype, 1) + tmp1 = T.alloc_var(dtype, 2) + T.copy(A[bx * block_N], B[bx * block_N]) + for i in T.Parallel(block_N): + B[bx * block_N + i] = tmp0 + tmp1 + + return main + + +def run_alloc_multi_vars_with_initializer( + N, + block_N, + dtype, +): + program = alloc_multi_vars_with_initializer(N, block_N, dtype) + + kernel = tilelang.compile(program, out_idx=[1]) + code = kernel.get_kernel_source(kernel_only=True) + assert code.count("= 1;") == 1 + assert code.count("= 2;") == 1 + + +def test_alloc_multi_vars_with_initializer(): + run_alloc_multi_vars_with_initializer(256, 64, T.int32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_annot.py b/tilelang/original/testing/python/language/test_tilelang_language_annot.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9aeeac6b137189655ef21c6740572077b339ba --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_annot.py @@ -0,0 +1,74 
@@ +import tilelang +import tilelang.language as T +import tilelang.testing +import torch + + +def test_tensor_annot_mul(): + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic("n") + + @T.prim_func + def kernel( + A: T.Tensor((n * 4,), T.int32), + ): + with T.Kernel(1) as _: + for i in range(n * 4): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device="cuda") + ker(A) + expected = torch.zeros(16, dtype=torch.int32, device="cuda") + assert torch.equal(A, expected) + + +def test_tensor_annot_add(): + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic("n") + + @T.prim_func + def kernel( + A: T.Tensor((n + 1,), T.int32), + ): + with T.Kernel(1) as _: + for i in range(n + 1): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device="cuda") + ker(A) + expected = torch.zeros(16, dtype=torch.int32, device="cuda") + assert torch.equal(A, expected) + + +def test_tensor_annot_mul_add(): + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic("n") + + @T.prim_func + def kernel( + A: T.Tensor((n * 3 + 1,), T.int32), + ): + with T.Kernel(1) as _: + for i in range(n * 3 + 1): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device="cuda") + ker(A) + expected = torch.zeros(16, dtype=torch.int32, device="cuda") + assert torch.equal(A, expected) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_annotate_safe_value.py b/tilelang/original/testing/python/language/test_tilelang_language_annotate_safe_value.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8239a15705a260a07cf370b321be3bebe12f27 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_annotate_safe_value.py @@ -0,0 +1,50 @@ +import tilelang +import tilelang.language as T +import tilelang.testing +import torch + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy(M, N, block_M, block_N, dtype=T.float16, pad_value=0): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.annotate_safe_value({A: pad_value}) + for i, j in T.Parallel(block_M, block_N): + A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j] + + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = A_shared[i, j] + + return main + + +def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16, pad_value=0): + program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value) + kernel = tilelang.compile( + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + ref_b = torch.zeros_like(a) + for i in range(M): + if i >= 10: + ref_b[i, :] = a[i - 10, :] + else: + ref_b[i, :] = pad_value + torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy(): + run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, pad_value=10) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/tilelang/original/testing/python/language/test_tilelang_language_any_of.py b/tilelang/original/testing/python/language/test_tilelang_language_any_of.py
new file mode 100644
index 0000000000000000000000000000000000000000..74db94f7c29502bd2898650cf5ac3cfcbef1e4b6
--- /dev/null
+++ b/tilelang/original/testing/python/language/test_tilelang_language_any_of.py
@@ -0,0 +1,315 @@
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+import torch
+
+
+def ref_program(A, B, BlockMask, block_M, block_N, block_K):
+    M, K = A.shape
+    N = B.shape[1]
+    ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
+    for i in range(M // block_M):
+        for j in range(N // block_N):
+            accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
+            for k in range(K // block_K):
+                if torch.any(BlockMask[i, j, k]):
+                    accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
+                        k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
+                    ].to(torch.float32)
+            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
+    return ref_c
+
+
+def blocksparse_matmul_global(
+    M,
+    N,
+    K,
+    condition_dim,
+    block_M,
+    block_N,
+    block_K,
+    num_stages,
+    thread_num,
+    enable_rasteration,
+    dtype=T.float16,
+    accum_dtype=T.float32,
+):
+    block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        BlockMask: T.Tensor(block_mask_shape, "bool"),
+        C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
+
+            T.use_swizzle(panel_size=10, enable=enable_rasteration)
+            T.clear(C_local)
+
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if T.any_of(BlockMask[by, bx, k, :]):
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                    T.gemm(A_shared, B_shared, C_local)
+
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def blocksparse_matmul_shared(
+    M,
+    N,
+    K,
+    condition_dim,
+    block_M,
+    block_N,
+    block_K,
+    num_stages,
+    thread_num,
+    enable_rasteration,
+    dtype=T.float16,
+    accum_dtype=T.float32,
+):
+    block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        BlockMask: T.Tensor(block_mask_shape, "bool"),
+        C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            block_mask_shared = T.alloc_shared(condition_dim, "bool")
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
+
+            T.use_swizzle(panel_size=10, enable=enable_rasteration)
+            T.clear(C_local)
+
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
+                for i in T.serial(condition_dim):
+                    block_mask_shared[i] = BlockMask[by, bx, k, i]
+                # or T.any_of(block_mask_shared[0:condition_dim])
+                # or T.any_of(block_mask_shared[:])
+                if T.any_of(block_mask_shared):
+                    T.copy(A[by * block_M, k * 
block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def blocksparse_matmul_local( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + dtype=T.float16, + accum_dtype=T.float32, +): + block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + block_mask_local = T.alloc_local(condition_dim, "bool") + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + for i in T.serial(condition_dim): + block_mask_local[i] = BlockMask[by, bx, k, i] + # or T.any_of(block_mask_local[0:condition_dim]) + # or T.any_of(block_mask_local[:]) + if T.any_of(block_mask_local): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_global( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile(func, out_idx=-1) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_shared( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile( + func, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + # Create block mask with desired sparsity + mask_shape = 
(M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2): + block_M = 128 + block_N = 128 + block_K = 32 + num_stages = 2 + thread_num = 128 + enable_rasteration = True + + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + func = blocksparse_matmul_local( + M, + N, + K, + condition_dim, + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasteration, + ) + kernel = tilelang.compile( + func, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim) + # random set the last dimension to be False + block_mask[:, :, :, 0] = False + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def test_block_sparse_matmul_global(): + run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +def test_block_sparse_matmul_shared(): + run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +def test_block_sparse_matmul_local(): + run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_assume.py b/tilelang/original/testing/python/language/test_tilelang_language_assume.py new file mode 100644 index 0000000000000000000000000000000000000000..06e92dfa998c99009f641325d5811aa90af667a3 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_assume.py @@ -0,0 +1,86 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +def test_assume_remove_boundary_check(): + @tilelang.jit + def kernel_with_assume(): + N = T.dynamic("N") + + @T.prim_func + def main(A: T.Tensor((N,), T.float32), l: T.int32, r: T.int32): + with T.Kernel(1, threads=32) as _: + for i in T.serial(r - l + 1): + T.assume(l + i >= 0 and l + i < N) + A[l + i] = 0 + + return main + + jit_kernel = kernel_with_assume() + source = jit_kernel.get_kernel_source() + + assert "if (" not in source + + +def test_assume_enable_vectorization(): + @tilelang.jit + def kernel_vectorize(M): + N = T.dynamic("N") + vectorize_size = 4 + + @T.prim_func + def main( + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), + ): + with T.Kernel(1, threads=32) as _: + tid = T.get_thread_binding() + + 
base_idx = tid * 4 + T.assume(N % vectorize_size == 0) + + for i in T.vectorized(vectorize_size): + T.assume(base_idx + i < N) + B[tid, base_idx + i] = A[tid, base_idx + i] + + return main + + jit_kernel = kernel_vectorize(128) + source = jit_kernel.get_kernel_source() + + assert ("float4" in source) and ("if (" not in source) + + +def test_assume_complex_indexing(): + @tilelang.jit + def kernel_complex(): + M = T.dynamic("M") + N = T.dynamic("N") + + @T.prim_func + def main( + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), + ): + with T.Kernel(1, threads=32) as _: + tid = T.get_thread_binding() + for j in T.serial(N): + i_src = T.min(j + 233, tid + 2) + j_src = j * T.ceildiv(j, i_src) * j - 1 + + T.assume(i_src >= 0 and i_src < M) + T.assume(j_src >= 0 and j_src < N) + + B[tid, j] = A[i_src, j_src] + + return main + + jit_kernel = kernel_complex() + source = jit_kernel.get_kernel_source() + + assert "if (" not in source + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_atomic_add.py b/tilelang/original/testing/python/language/test_tilelang_language_atomic_add.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4dff7b38b34c0ce56a7569c8f473521ab48f13 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_atomic_add.py @@ -0,0 +1,365 @@ +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit +def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) + + return atomic_add + + +def run_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] += A[k, i, j] + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def tile_atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + T.atomic_add(B[bx * block_M, by * block_N], A_shared) + + return atomic_add + + +def run_tile_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) + print(kernel.get_kernel_source()) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] += A[k, i, j] + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, 
B) + print(B) + print(ref_B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) + + return atomic_max + + +def run_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = max(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) + + return atomic_min + + +def run_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = min(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_load_store_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + idx_i = bx * block_M + i + idx_j = by * block_N + j + if idx_i < M and idx_j < N: + val = T.atomic_load(A[idx_i, idx_j]) + T.atomic_store(B[idx_i, idx_j], val) + + return atomic_load_store + + +def run_atomic_load_store(M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype) + import torch + + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + kernel(A, B) + torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_memory_order_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * 
block_M, by * block_N : (by + 1) * block_N], A_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed") + + return atomic_with_memory_order + + +def run_atomic_memory_order(K, M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_memory_order_program(K, M, N, block_M, block_N, dtype=dtype) + import torch + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] += A[k, i, j] + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_addx2_program(M, N, block_M, block_N): + @T.prim_func + def atomic_addx2(A: T.Tensor((M, N), T.float16), B: T.Tensor((M, N), T.float16)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N // 2): + idx_i = bx * block_M + i + idx_j = by * block_N + j * 2 + T.atomic_addx2(B[idx_i, idx_j], A[idx_i, idx_j]) + + return atomic_addx2 + + +def run_atomic_addx2(M, N, block_M, block_N): + kernel = atomic_addx2_program(M, N, block_M, block_N) + import torch + + A = torch.randn(M, N, dtype=torch.float16).cuda() + B = torch.zeros(M, N, dtype=torch.float16).cuda() + ref_B = B.clone() + + for i in range(M): + for j in range(0, N - 1, 2): + ref_B[i, j] += A[i, j] + ref_B[i, j + 1] += A[i, j + 1] + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +def test_atomic_add(): + run_atomic_add(8, 128, 128, 32, 32) + + +def test_atomic_max(): + run_atomic_max(4, 64, 64, 16, 16) + + +def test_atomic_min(): + run_atomic_min(4, 64, 64, 16, 16) + + +def test_atomic_load_store(): + run_atomic_load_store(64, 64, 16, 16) + + +def test_atomic_memory_order(): + run_atomic_memory_order(4, 64, 64, 16, 16) + + +def test_atomic_addx2(): + run_atomic_addx2(32, 64, 8, 16) + + +@tilelang.jit +def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_different_orders( + A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype) + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + idx_i = bx * block_M + i + idx_j = by * block_N + j + if idx_i < M and idx_j < N: + val = A[idx_i, idx_j] + T.atomic_add(B[idx_i, idx_j], val, memory_order="release") + T.atomic_max(C[idx_i, idx_j], val, memory_order="relaxed") + T.atomic_min(D[idx_i, idx_j], val, memory_order="relaxed") + + return atomic_different_orders + + +def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=dtype) + import torch + + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() + + kernel(A, B, C, D) + + torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A)) + torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A)) + + +@tilelang.jit +def atomic_addx4_program(M, N, block_M, block_N): + @T.prim_func + def 
atomic_addx4(A: T.Tensor((M, N), T.float32), B: T.Tensor((M, N), T.float32)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N // 4): + idx_i = bx * block_M + i + idx_j = by * block_N + j * 4 + T.atomic_addx4(B[idx_i, idx_j], A[idx_i, idx_j]) + + return atomic_addx4 + + +def run_atomic_addx4(M, N, block_M, block_N): + kernel = atomic_addx4_program(M, N, block_M, block_N) + import torch + + A = torch.randn(M, N, dtype=torch.float32).cuda() + B = torch.zeros(M, N, dtype=torch.float32).cuda() + ref_B = B.clone() + + for i in range(M): + for j in range(0, N - 3, 4): + ref_B[i, j] += A[i, j] + ref_B[i, j + 1] += A[i, j + 1] + ref_B[i, j + 2] += A[i, j + 2] + ref_B[i, j + 3] += A[i, j + 3] + + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +@tilelang.jit +def atomic_return_prev_program(M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + idx_i = bx * block_M + i + idx_j = by * block_N + j + if idx_i < M and idx_j < N: + old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True) + + return atomic_with_return_prev + + +def run_atomic_return_prev(M, N, block_M, block_N, dtype=T.float32): + kernel = atomic_return_prev_program(M, N, block_M, block_N, dtype=dtype) + import torch + + A = torch.ones(M, N, dtype=getattr(torch, dtype)).cuda() * 5.0 + B = torch.ones(M, N, dtype=getattr(torch, dtype)).cuda() * 2.0 + old_vals = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() + + initial_B = B.clone() + kernel(A, B, old_vals) + + torch.testing.assert_close(old_vals, initial_B, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(B, initial_B + A, atol=1e-3, rtol=1e-3) + + +def test_atomic_different_memory_orders(): + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float32) + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float16) + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.bfloat16) + + +def test_atomic_addx4(): + run_atomic_addx4(16, 64, 4, 4) + + +def test_atomic_return_prev(): + run_atomic_return_prev(32, 32, 8, 8) + + +def test_tile_atomic_add(): + run_tile_atomic_add(8, 128, 128, 32, 32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_ceildiv.py b/tilelang/original/testing/python/language/test_tilelang_language_ceildiv.py new file mode 100644 index 0000000000000000000000000000000000000000..f5af31b83971de1fd425e6aca2d29b043990b75f --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_ceildiv.py @@ -0,0 +1,57 @@ +import tilelang.language as T +import tilelang.testing +import torch + + +@tilelang.jit(out_idx=[-1]) +def _ceildiv_kernel(a: int, b: int): + @T.prim_func + def ceildiv_kernel(A: T.Tensor((1,), T.int32)): + with T.Kernel(1, threads=1) as _: + A[0] = T.ceildiv(T.int32(a), T.int32(b)) + + return ceildiv_kernel + + +def run_ceildiv(a=128, b=32): + kernel = _ceildiv_kernel(a, b) + A = kernel() + print(kernel.get_kernel_source()) + print(A) + + +def test_ceildiv(): + run_ceildiv(a=128, b=32) + run_ceildiv(a=1, b=32) + run_ceildiv(a=-1, b=32) + run_ceildiv(a=-2, b=32) + + +@tilelang.jit +def _ceildiv_kernel_dyn(b: int): + @T.prim_func + def 
ceildiv_kernel(A: T.Tensor((1,), T.int32), a: T.int32): + with T.Kernel(1, threads=1) as _: + A[0] = T.ceildiv(T.int32(a), T.int32(b)) + + return ceildiv_kernel + + +def run_ceildiv_dyn(a=128, b=32): + kernel = _ceildiv_kernel_dyn(b) + A = torch.empty((1,), dtype=torch.int32, device="cuda") + kernel(A, a) + print(kernel.get_kernel_source()) + print(A) + + +@tilelang.testing.requires_cuda +def test_ceildiv_dyn(): + run_ceildiv_dyn(a=128, b=32) + run_ceildiv_dyn(a=1, b=32) + run_ceildiv_dyn(a=-1, b=32) + run_ceildiv_dyn(a=-2, b=32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_chain_equal.py b/tilelang/original/testing/python/language/test_tilelang_language_chain_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..083eefdcb493de9e77ed332db3e276adbe0e4008 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_chain_equal.py @@ -0,0 +1,46 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import torch + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def chain_equal(N, block_size, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx: + for lane in T.Parallel(block_size): + idx = bx * block_size + lane + A[idx] = B[idx] = C[idx] = 1 + + return main + + +def run_chain_equal(N=128, block_size=64, dtype=T.float32): + kernel = chain_equal(N, block_size, dtype) + A = torch.zeros((N,), dtype=torch.float32, device="cuda") + B = torch.zeros((N,), dtype=torch.float32, device="cuda") + C = torch.zeros((N,), dtype=torch.float32, device="cuda") + kernel(A, B, C) + ref = torch.ones_like(A) + torch.testing.assert_close(A, ref) + torch.testing.assert_close(B, ref) + torch.testing.assert_close(C, ref) + + +@tilelang.testing.requires_cuda +def test_chain_equal(): + run_chain_equal() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_clamp.py b/tilelang/original/testing/python/language/test_tilelang_language_clamp.py new file mode 100644 index 0000000000000000000000000000000000000000..372d7478468c8890aa8ab56152414952b695c0f5 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_clamp.py @@ -0,0 +1,117 @@ +import tilelang.testing +from tilelang import language as T + + +def clamp_within_bounds( + N, + block_N, + dtype, + min_val=None, + max_val=None, +): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + A_shared = T.alloc_shared([block_N], dtype) + T.copy(A[bx * block_N], A_shared) + for i in T.Parallel(block_N): + A_shared[i] = T.clamp(A_shared[i], min_val=min_val, max_val=max_val) + T.copy(A_shared, B[bx * block_N]) + + return main + + +def run_clamp( + N, + block_N, + dtype, + min=None, + max=None, +): + program = clamp_within_bounds(N, block_N, dtype, min, max) + + kernel = tilelang.compile(program, out_idx=[1]) + profiler = kernel.get_profiler() + + def ref_program(A): + import torch + + output = torch.clamp(A, min, max) + return output + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def clamp_value_range( + N, + 
block_N, + dtype, +): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((1, N), dtype), + B: T.Tensor((1, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + # A_shared = T.alloc_shared([1, block_N], dtype=dtype) + A_frag = T.alloc_fragment([1, block_N], dtype=dtype) + min_frag = T.alloc_fragment([1], dtype=dtype) + max_frag = T.alloc_fragment([1], dtype=dtype) + T.copy(A[0, bx * block_N], A_frag) + T.reduce_min(A_frag, min_frag, dim=1) + T.reduce_max(A_frag, max_frag, dim=1) + for i in T.Parallel(block_N): + # A_frag[0, i] = T.max(A_frag[0, i], min_frag[0] * 0.5) + # A_frag[0, i] = T.min(A_frag[0, i], max_frag[0] * 0.5) + A_frag[0, i] = T.clamp(A_frag[0, i], min_frag[0] * 0.5, max_frag[0] * 0.5) + T.copy(A_frag, B[0, bx * block_N]) + + return main + + +def run_clamp_value_range( + N, + block_N, + dtype, +): + program = clamp_value_range( + N, + block_N, + dtype, + ) + kernel = tilelang.compile(program, out_idx=[1]) + + import torch + + # Convert string dtype to torch.dtype + torch_dtype = dtype.as_torch() + + def ref_program(A): + min_val = torch.min(A) * 0.5 + max_val = torch.max(A) * 0.5 + output = torch.clamp(A, min_val, max_val) + return output + + A = torch.randint(-5, 5, (1, N)).cuda().to(dtype=torch_dtype) + B = kernel(A) + ref_b = ref_program(A) + torch.testing.assert_close(B, ref_b) + + +def test_clamp(): + # clamp tests for float16 and float32 + run_clamp(1024, 128, T.float16, -0.05, 0.05) + run_clamp(1024, 128, T.float32, -0.06, 0.05) + run_clamp_value_range(1024, 128, T.float16) + run_clamp_value_range(1024, 128, T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_clear.py b/tilelang/original/testing/python/language/test_tilelang_language_clear.py new file mode 100644 index 0000000000000000000000000000000000000000..af9d89631f725985adad445db2ef919d29c784eb --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_clear.py @@ -0,0 +1,59 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + T.clear(A_shared) + + # Demonstrate parallelized copy from global to shared for B + T.copy(B[bx * block_N, ko * block_K], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) + kernel = tilelang.compile(program, out_idx=[2], target="cuda", 
pass_configs={"tl.disable_tma_lower": True}) + import torch + from tilelang.utils import map_torch_type + + a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda() + b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda() + c = kernel(a, b) + assert torch.allclose(c, torch.zeros_like(c)) + + +def test_matmul(): + run_matmul(1024, 1024, 1024, 128, 128, 32) + + +if __name__ == "__main__": + test_matmul() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_composable_index.py b/tilelang/original/testing/python/language/test_tilelang_language_composable_index.py new file mode 100644 index 0000000000000000000000000000000000000000..7893c1f2438a02b00bf9d4b3d71b6464c4ec6c6e --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_composable_index.py @@ -0,0 +1,51 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import torch + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_composable_copy(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M * N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], dtype) + B_local = T.alloc_fragment([block_M * block_N], dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + for i, j in T.Parallel(block_M, block_N): + B_local[i * block_N + j] = A_local[i, j] + for i in T.Parallel(block_M * block_N): + B[by * block_M * N + bx * block_N + i // block_N * N + i % block_N] = B_local[i] + + return main + + +def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_composable_copy(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy(): + run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128) + run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576) + run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_copy.py b/tilelang/original/testing/python/language/test_tilelang_language_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..29bb0f9514851ef5a696e27ab349d300d49aca0d --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_copy.py @@ -0,0 +1,188 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing + +print(torch.__version__) + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy(M, N, block_M, block_N, src_dtype=T.float16, dst_dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), src_dtype), + B: T.Tensor((M, N), dst_dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + T.copy( + A[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], + B[by * block_M : (by + 1) * 
block_M, bx * block_N : (bx + 1) * block_N], + ) + + return main + + +def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + target="cuda", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) + source = kernel.get_kernel_source() + print(source) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy(): + run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128) + run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576) + run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32) + + +def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.StridedTensor((M, N), (NN, 1), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] + + return main + + +def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype=T.float16): + if isinstance(NN, int): + assert NN > N, "NN must be greater than N" + program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, + ) + if isinstance(NN, T.Var): + NN = N * 2 + a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a[:, :N]) + torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_with_stride(): + run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128) + run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128) + + +def tilelang_copy_bufferload(num_tokens, dtype=T.float16): + @T.prim_func + def main( + indices: T.Tensor((num_tokens,), T.int32), + x: T.Tensor((num_tokens,), dtype), + ): + with T.Kernel(num_tokens, threads=32) as pid: + idx = T.alloc_local([1], T.int32) + T.copy(indices[pid], idx[0]) + x[idx[0]] = x[idx[0]] + 1 + + return main + + +def run_tilelang_copy_bufferload(num_tokens=128, dtype=T.float16): + program = tilelang_copy_bufferload(num_tokens, dtype) + # test compilation only + tilelang.compile( + program, + out_idx=[1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) + + +def test_tilelang_copy_bufferload(): + run_tilelang_copy_bufferload(num_tokens=128) + + +def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]) + + return main + + +def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = 
tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + target="cuda", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_buffer_load_with_parallel(): + run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128) + + +def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu): + program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + ) + source = kernel.get_kernel_source() + assert "fp8_e8_t" in source + dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8).view(torch.float8_e8m0fnu) + output = kernel(dummy_input) + assert output is not None + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(10, 0) +def test_tilelang_copy_fp8_e8m0(): + run_tilelang_copy_fp8_e8m0(src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu) + + +def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn): + program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + ) + source = kernel.get_kernel_source() + assert "fp4_e2_t" in source + # For FP4, use same shape as kernel expects, since int8 is used as storage type + dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8) + output = kernel(dummy_input) + assert output is not None + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(10, 0) +def test_tilelang_copy_fp4(): + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn) + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float16) + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.bfloat16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_cumsum.py b/tilelang/original/testing/python/language/test_tilelang_language_cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..fecc0d2a88b40f1d94e9909add3f749699af0a5f --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_cumsum.py @@ -0,0 +1,311 @@ +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl +import torch +import tilelang.language as T + + +def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): + @T.prim_func + def cumsum( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.cumsum(src=A_shared, dim=dim, reverse=reverse) + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return cumsum + + +def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): + import tilelang.language as T + + @T.prim_func + def cumsum( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize 
Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + A_fragment = T.alloc_fragment((block_M, block_N), dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(A_shared, A_fragment) + T.cumsum(src=A_fragment, dim=dim, reverse=reverse) + T.copy(A_fragment, B[by * block_M, bx * block_N]) + + return cumsum + + +def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32, scope="smem"): + if scope == "smem": + program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype) + elif scope == "fragment": + program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() + + def ref_program(A): + ref_b = torch.empty_like(A) + for i in range(M // block_M): + for j in range(N // block_N): + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[ + i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N + ].cumsum(dim=dim) + if reverse: + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = ( + A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] + .flip(dims=[dim]) + .cumsum(dim=dim) + .flip(dims=[dim]) + ) + return ref_b + + tilelang_res = jit_kernel(A) + ref_res = ref_program(A) + torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) + + +def cumsum_smem_test_1d(N, block_N, reverse=False, dtype=T.float32): + import tilelang.language as T + + @T.prim_func + def cumsum( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + A_shared = T.alloc_shared((block_N,), dtype) + + T.copy(A[bx * block_N], A_shared) + T.cumsum(src=A_shared, dim=0, reverse=reverse) + T.copy(A_shared, B[bx * block_N]) + + return cumsum + + +def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype=T.float32): + import tilelang.language as T + + @T.prim_func + def cumsum( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: + A_shared = T.alloc_shared((block_N,), dtype) + A_fragment = T.alloc_fragment((block_N,), dtype) + + T.copy(A[bx * block_N], A_shared) + T.copy(A_shared, A_fragment) + T.cumsum(src=A_fragment, dim=0, reverse=reverse) + T.copy(A_fragment, B[bx * block_N]) + + return cumsum + + +def run_cumsum_1d(N, block_N, reverse=False, dtype=T.float32, scope="smem"): + if scope == "smem": + program = cumsum_smem_test_1d(N, block_N, reverse, dtype) + elif scope == "fragment": + program = cumsum_fragment_test_1d(N, block_N, reverse, dtype) + else: + raise ValueError(f"Unknown scope {scope}") + + jit_kernel = tl.compile(program, out_idx=-1) + A = torch.randn(N, dtype=getattr(torch, dtype)).cuda() + + def ref_program(A): + ref_b = torch.empty_like(A) + num_blocks = (N + block_N - 1) // block_N + for j in range(num_blocks): + start = j * block_N + end = min(start + block_N, N) + chunk = A[start:end] + if reverse: + chunk = torch.flip(chunk, dims=[0]) + chunk = chunk.cumsum(dim=0) + if reverse: + chunk = torch.flip(chunk, dims=[0]) + ref_b[start:end] = chunk + return ref_b + + tilelang_res = jit_kernel(A) + ref_res = ref_program(A) + torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) + + +def test_cumsum_smem(): + # Test different sizes + run_cumsum(1024, 1024, 128, 128) + run_cumsum(1024, 1024, 128, 128, dim=1) + 
run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True) + + # Test different dtypes + run_cumsum(256, 256, 128, 128, dtype=T.float32) + run_cumsum(256, 256, 128, 128, dtype=T.float32) + + +def test_cumsum_fragment(): + run_cumsum(1024, 1024, 128, 128, scope="fragment") + run_cumsum(1024, 1024, 128, 128, dim=1, scope="fragment") + run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment") + + # Test different dtypes + run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment") + run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment") + + +def test_cumsum_smem_1d(): + run_cumsum_1d(1024, 128) + run_cumsum_1d(1024, 128, reverse=True) + + +def test_cumsum_fragment_1d(): + run_cumsum_1d(1024, 128, scope="fragment") + run_cumsum_1d(1024, 128, reverse=True, scope="fragment") + + +def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype=T.float32): + """Test cumsum with buffer region (slice) as input.""" + import tilelang.language as T + + @T.prim_func + def cumsum_region( + InputG_fragment: T.Tensor((N,), dtype), + OutputG_fragment: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, chunk_size), threads=chunk_size) as bx: + i = bx + chunk_start = i * chunk_size + # Copy region to shared memory first (cumsum only supports shared memory) + A_shared = T.alloc_shared((chunk_size,), dtype) + T.copy(InputG_fragment[chunk_start : chunk_start + chunk_size], A_shared) + # Test cumsum with region input - in-place operation on shared memory + # This demonstrates the feature: T.cumsum(region, dim=0) + T.cumsum(src=A_shared, dim=0, reverse=reverse) + # Copy result back to global memory + T.copy(A_shared, OutputG_fragment[chunk_start : chunk_start + chunk_size]) + + return cumsum_region + + +def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype=T.float32): + """Run test for cumsum with region input.""" + program = cumsum_region_test_1d(N, chunk_size, reverse, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + A = torch.randn(N, dtype=getattr(torch, dtype)).cuda() + + def ref_program(A): + ref_b = torch.empty_like(A) + num_blocks = (N + chunk_size - 1) // chunk_size + for j in range(num_blocks): + start = j * chunk_size + end = min(start + chunk_size, N) + chunk = A[start:end].clone() + if reverse: + chunk = torch.flip(chunk, dims=[0]) + chunk = chunk.cumsum(dim=0) + if reverse: + chunk = torch.flip(chunk, dims=[0]) + ref_b[start:end] = chunk + return ref_b + + tilelang_res = jit_kernel(A) + ref_res = ref_program(A) + torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) + + +def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): + """Test cumsum with buffer region (slice) as input in 2D.""" + import tilelang.language as T + + @T.prim_func + def cumsum_region( + InputG_fragment: T.Tensor((M, N), dtype), + OutputG_fragment: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + chunk_start_M = by * block_M + chunk_start_N = bx * block_N + # Copy region to shared memory first (cumsum only supports shared memory) + A_shared = T.alloc_shared((block_M, block_N), dtype) + T.copy( + InputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : chunk_start_N + block_N], + A_shared, + ) + # Test cumsum with 2D region input - in-place operation on shared memory + T.cumsum(src=A_shared, dim=dim, reverse=reverse) + # Copy result back to global memory + T.copy( + A_shared, + OutputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : 
chunk_start_N + block_N], + ) + + return cumsum_region + + +def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): + """Run test for cumsum with 2D region input.""" + program = cumsum_region_test_2d(M, N, block_M, block_N, dim, reverse, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() + + def ref_program(A): + ref_b = torch.empty_like(A) + num_blocks_M = (M + block_M - 1) // block_M + num_blocks_N = (N + block_N - 1) // block_N + for i in range(num_blocks_M): + for j in range(num_blocks_N): + start_M = i * block_M + end_M = min(start_M + block_M, M) + start_N = j * block_N + end_N = min(start_N + block_N, N) + chunk = A[start_M:end_M, start_N:end_N].clone() + if reverse: + chunk = torch.flip(chunk, dims=[dim]) + chunk = chunk.cumsum(dim=dim) + if reverse: + chunk = torch.flip(chunk, dims=[dim]) + ref_b[start_M:end_M, start_N:end_N] = chunk + return ref_b + + tilelang_res = jit_kernel(A) + ref_res = ref_program(A) + torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) + + +def test_cumsum_region_1d(): + """Test cumsum with 1D region input.""" + # Test normal cumsum with region input + run_cumsum_region_1d(1024, 128) + # Test reverse cumsum with region input + run_cumsum_region_1d(1024, 128, reverse=True) + # Test with different chunk sizes + run_cumsum_region_1d(512, 64) + run_cumsum_region_1d(2048, 256) + # Tail coverage (non-divisible size) + run_cumsum_region_1d(1000, 128) + + +def test_cumsum_region_2d(): + """Test cumsum with 2D region input.""" + # Test 2D cumsum along dim 0 + run_cumsum_region_2d(1024, 1024, 128, 128, dim=0) + # Test 2D cumsum along dim 1 + run_cumsum_region_2d(1024, 1024, 128, 128, dim=1) + # Test reverse cumsum + run_cumsum_region_2d(512, 512, 64, 64, dim=1, reverse=True) + # Tail coverage (non-divisible size) + run_cumsum_region_2d(1000, 1000, 128, 128, dim=1) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_frontend_v2.py b/tilelang/original/testing/python/language/test_tilelang_language_frontend_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..67115e8c2fd6b921f7eaa5f13996a0ef2d8d775d --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_frontend_v2.py @@ -0,0 +1,474 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import tvm +from tvm.script.ir_builder.base import IRBuilderFrame +from tvm.tir.expr import IntImm, Var + + +def test_argument(): + @T.prim_func + def test_argument( + t_1: T.bool, + t_2: T.short, + t_3: T.int, + t_4: T.long, + t_5: T.half, + t_6: T.float, + t_7: T.long, + t_8: T.int8, + t_9: T.int16, + t_10: T.int32, + t_11: T.int64, + t_12: T.uint8, + t_13: T.uint16, + t_14: T.uint32, + t_15: T.uint64, + t_16: T.float8_e4m3fn, + t_17: T.float8_e4m3fnuz, + t_18: T.float8_e5m2, + t_19: T.float8_e5m2fnuz, + t_20: T.float8_e8m0fnu, + t_21: T.float16, + t_22: T.bfloat16, + t_23: T.float32, + t_24: T.float64, + ): + pass + + +def test_expr(): + from tilelang.language.v2.dtypes import _all_dtypes + + errors = [] + for name in _all_dtypes: + dtype = getattr(T, name) + assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType" + try: + dtype(1.0) + dtype() + except TypeError: + pass + except Exception: + errors.append(name) + assert not errors + + +# def test_var_decl_sugar(): + +# @T.prim_func +# def test_var_decl_sugar(): +# with T.Kernel(128, 
128) as (bx, by): +# var_1: T.bool = 1.0 +# var_2: T.short = 1.0 +# var_3: T.int = 1.0 +# var_4: T.long = 1.0 +# var_5: T.half = 1.0 +# var_6: T.float = 1.0 +# var_7: T.long = 1.0 +# var_8: T.int8 = 1.0 +# var_9: T.int16 = 1.0 +# var_10: T.int32 = 1.0 +# var_11: T.int64 = 1.0 +# var_12: T.uint8 = 1.0 +# var_13: T.uint16 = 1.0 +# var_14: T.uint32 = 1.0 +# var_15: T.uint64 = 1.0 +# var_16: T.float8_e4m3fn = 1.0 +# var_17: T.float8_e4m3fnuz = 1.0 +# var_18: T.float8_e5m2 = 1.0 +# var_19: T.float8_e5m2fnuz = 1.0 +# var_20: T.float8_e8m0fnu = 1.0 +# var_21: T.float16 = 1.0 +# var_22: T.bfloat16 = 1.0 +# var_23: T.float32 = 1.0 +# var_24: T.float64 = 1.0 +# var_1: T.bool = var_1 +# var_2: T.short = var_2 +# var_3: T.int = var_3 +# var_4: T.long = var_4 +# var_5: T.half = var_5 +# var_6: T.float = var_6 +# var_7: T.long = var_7 +# var_8: T.int8 = var_8 +# var_9: T.int16 = var_9 +# var_10: T.int32 = var_10 +# var_11: T.int64 = var_11 +# var_12: T.uint8 = var_12 +# var_13: T.uint16 = var_13 +# var_14: T.uint32 = var_14 +# var_15: T.uint64 = var_15 +# var_16: T.float8_e4m3fn = var_16 +# var_17: T.float8_e4m3fnuz = var_17 +# var_18: T.float8_e5m2 = var_18 +# var_19: T.float8_e5m2fnuz = var_19 +# var_20: T.float8_e8m0fnu = var_20 +# var_21: T.float16 = var_21 +# var_22: T.bfloat16 = var_22 +# var_23: T.float32 = var_23 +# var_24: T.float64 = var_24 + +# s = test_var_decl_sugar.script() +# for i in range(1, 25): +# assert f'var_{i}_1' in s +# assert 'tl.local_var_init' in s + + +def test_dtype_str_repr(): + @T.prim_func + def test_str_repr(): + buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared") # noqa F841 + buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared") # noqa F841 + buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared") # noqa F841 + buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared") # noqa F841 + buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared") # noqa F841 + buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared") # noqa F841 + buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared") # noqa F841 + buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared") # noqa F841 + buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared") # noqa F841 + buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared") # noqa F841 + buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared") # noqa F841 + buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared") # noqa F841 + buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared") # noqa F841 + buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared") # noqa F841 + buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared") # noqa F841 + buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared") # noqa F841 + buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared") # noqa F841 + buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared") # noqa F841 + buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared") # noqa F841 + buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared") # noqa F841 + buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared") # noqa F841 + buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared") # noqa F841 + + +# not supported now +# def test_torch_eq(): +# dtypes = [ +# T.bool, +# T.short, +# T.int, +# T.long, +# T.half, +# T.float, +# T.long, +# T.int8, +# 
T.int16, +# T.int32, +# T.int64, +# T.uint8, +# T.uint16, +# T.uint32, +# T.uint64, +# T.float8_e4m3fn, +# T.float8_e4m3fnuz, +# T.float8_e5m2, +# T.float8_e5m2fnuz, +# T.float8_e8m0fnu, +# T.float16, +# T.bfloat16, +# T.float32, +# T.float64, +# ] +# torch_dtypes = [ +# torch.bool, +# torch.short, +# torch.int, +# torch.long, +# torch.half, +# torch.float, +# torch.long, +# torch.int8, +# torch.int16, +# torch.int32, +# torch.int64, +# torch.uint8, +# torch.uint16, +# torch.uint32, +# torch.uint64, +# torch.float8_e4m3fn, +# torch.float8_e4m3fnuz, +# torch.float8_e5m2, +# torch.float8_e5m2fnuz, +# torch.float8_e8m0fnu, +# torch.float16, +# torch.bfloat16, +# torch.float32, +# torch.float64, +# ] +# for a, b in zip(dtypes, torch_dtypes): +# assert a == b, f"{a} and {b} are not equal" +# assert T.dtype(b) == a, "dtype conversion error" + + +def test_var_assign(): + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_var_assign(A: T.Tensor((2,), T.int32)): + with T.Kernel(1) as _: + a: T.int32 = 1 + b: T.int32 = a + a = 2 + d: T.int32 = a + A[0] = b + A[1] = d + + res = test_var_assign()() + assert res[0] == 1 + assert res[1] == 2 + + +def test_marco_return(): + @T.macro + def macro_return_constant(): + return 0 + + @T.macro + def macro_return_frame(x): + return T.alloc_var(T.float32, init=x) + + @T.macro + def macro_return_expr(x): + y = x + 1.0 + return y + + @T.macro + def macro_apply_func(x, fn): + return fn(x) + + def check(x, ty): + assert isinstance(x, ty) + + @T.prim_func + def test_macro_return(): + with T.Kernel(1) as _: + a = macro_return_constant() + b = macro_return_frame(3.0) + c = macro_return_expr(4.0) + d = macro_apply_func(5.0, lambda x: x * 2.0) + check(a, (int, float, T.PrimExpr)) + check(b, (int, float, T.PrimExpr)) + check(c, (int, float, T.PrimExpr)) + check(d, (int, float, T.PrimExpr)) + + +def test_prim_func_generator(): + @T.prim_func(generator=True) + def prim_func_gen( + A=T.Tensor((128,), T.float32), # noqa: B008 + B=T.Tensor((128,), T.float32), # noqa: B008 + ): + with T.Kernel(128) as (tx,): + T.copy(A[tx], B[tx]) + + prim_func_gen() + + @T.prim_func + def foo() -> T.Tensor((128,), T.float32): + pass + + assert isinstance(foo, T.PrimFunc) + + +def test_serial_for_with_step(): + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_stepped_serial(A: T.Tensor((10,), T.int32)): + with T.Kernel(1) as _: + for i in range(0, 10, 2): + T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range") + A[i] = 1.0 + for i in range(1, 10, 2): + T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range") + A[i] = 2.0 + + ker = test_stepped_serial() + res = ker() + ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda") + assert torch.all(res == ref), f"Expected {ref}, but got {res}" + + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_serial_step_neg(A: T.Tensor((10,), T.int32)): + with T.Kernel(1) as _: + for i in range(10, 0, -1): + T.device_assert(0 < i <= 10, "i out of range") + A[10 - i] = i + + ker = test_serial_step_neg() + res = ker() + ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda") + assert torch.all(res == ref), f"Expected {ref}, but got {res}" + + assert isinstance(T.serial(1, 10, 1), IRBuilderFrame) + assert isinstance(T.serial(1, 10, IntImm(T.int32, 1)), IRBuilderFrame) + assert not isinstance(T.serial(1, 10, Var("tmp", T.int32)), IRBuilderFrame) + assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame) + + +def test_swap_logic(): + @tilelang.jit + @T.prim_func + def swap_var(A: 
T.Tensor[(2,), T.float32]): + with T.Kernel(1, threads=1) as _: + a = T.alloc_var(T.float32, A[0]) + b = T.alloc_var(T.float32, A[1]) + a, b = b, a + A[0], A[1] = a, b + + @tilelang.jit + @T.prim_func + def swap_idx(A: T.Tensor[(2,), T.float32]): + with T.Kernel(1, threads=1) as _: + A[0], A[1] = A[1], A[0] + + k_swap_var = swap_var() + data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() + k_swap_var(data) + ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) + + k_swap_idx = swap_idx() + data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() + k_swap_idx(data) + ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) + + +def test_while_loop(): + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_while_loop(A: T.Tensor((1,), T.int32)): + with T.Kernel(1) as _: + i = T.alloc_var(T.int32, 0) + sum = T.alloc_var(T.int32) + while i < 10: + sum += i + i += 1 + A[0] = sum + + ker = test_while_loop() + A = ker() + assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" + + +def test_var_macro(): + try: + + @T.macro + def macro_with_var(x: T.Var): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = T.alloc_var(T.int32) + macro_with_var(x) + + assert "x[0] = 1" in prim_call_macro.script() + finally: + pass + + try: + + @T.macro + def macro_with_var(x: T.Var): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = 1 + macro_with_var(x) + + raise RuntimeError("Expect to report an error, x should not be passed as T.Var") + except ValueError: + pass + + try: + + @T.macro + def macro_with_var(x: T.Ref): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = T.alloc_var(T.int32) + macro_with_var(x) + + assert "x[0] = 1" in prim_call_macro.script() + finally: + pass + + try: + + @T.macro + def macro_with_var(x: T.Ref): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = 1 + macro_with_var(x) + + raise RuntimeError("Expect to report an error, x should not be passed as T.Var") + except ValueError: + pass + + +def test_frame_inside_macro(): + @tilelang.jit + def get_sample_kernel(): + @T.macro + def transform(x): + return x + 1 + + @T.prim_func + def sample_kernel( + num_blocks: T.int32, + idx_out: T.Tensor[(32,), T.int32], + ): + with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 + fragment = T.alloc_fragment(32, T.int32) + T.copy(idx_out, fragment) + + for i in T.Parallel(32): + idx_out[i] = transform(fragment[i]) + + return sample_kernel + + kernel = get_sample_kernel() # noqa: F841 + + +def test_buffer_slice_step(): + try: + + @T.prim_func + def prim_buffer_slice_step(A: T.Buffer((10,), T.int32), B: T.Buffer((5,), T.int32)): + with T.Kernel(1): + B[0:5:2] = A[0:10:2] + + raise AssertionError("Expect to report an error, buffer slice with step is not supported") + except RuntimeError: + pass + + +def test_boolop(): + a = Var("a", T.int32) + b = Var("b", T.int32) + c = Var("c", T.int32) + d = Var("d", T.int32) + + @T.macro + def cond(): + return not (a < b and b < c and a * d < b * d) or b * d < c * d + + cond() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_get_warp_info.py b/tilelang/original/testing/python/language/test_tilelang_language_get_warp_info.py new file mode 100644 index 
0000000000000000000000000000000000000000..e14cece9831472a992dbac77177041898e4ced7b --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_get_warp_info.py @@ -0,0 +1,206 @@ +from typing import Optional + +import tilelang.language as T +import tilelang.testing +import torch +from tilelang.utils.target import check_hip_availability + +_IS_HIP_AVAILABLE = check_hip_availability() +_DEFAULT_WARPS_PER_GROUP = 4 + + +def _resolve_warp_size(warp_size: Optional[int]) -> int: + if warp_size is not None: + return int(warp_size) + return 64 if _IS_HIP_AVAILABLE else 32 + + +def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: + if warps_per_group is not None: + return int(warps_per_group) + return _DEFAULT_WARPS_PER_GROUP + + +@tilelang.jit(out_idx=[-1]) +def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): + @T.prim_func + def laneid_kernel(A: T.Tensor((num_threads,), T.int32)): + with T.Kernel(1, threads=num_threads) as _: + tx = T.get_thread_binding() + A[tx] = T.get_lane_idx(warp_size) + + return laneid_kernel + + +@tilelang.jit(out_idx=[-1]) +def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None): + @T.prim_func + def warp_idx_sync_kernel(A: T.Tensor((num_threads,), T.int32)): + with T.Kernel(1, threads=num_threads) as _: + tx = T.get_thread_binding() + A[tx] = T.get_warp_idx_sync(warp_size) + + return warp_idx_sync_kernel + + +@tilelang.jit(out_idx=[-1]) +def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None): + @T.prim_func + def warp_idx_kernel(A: T.Tensor((num_threads,), T.int32)): + with T.Kernel(1, threads=num_threads) as _: + tx = T.get_thread_binding() + A[tx] = T.get_warp_idx(warp_size) + + return warp_idx_kernel + + +@tilelang.jit(out_idx=[-1]) +def _get_warp_group_idx_kernel( + num_threads: int = 128, + warp_size: Optional[int] = None, + warps_per_group: Optional[int] = None, +): + @T.prim_func + def warp_group_idx_kernel(A: T.Tensor((num_threads,), T.int32)): + with T.Kernel(1, threads=num_threads) as _: + tx = T.get_thread_binding() + A[tx] = T.get_warp_group_idx(warp_size, warps_per_group) + + return warp_group_idx_kernel + + +@tilelang.jit(out_idx=[-1]) +def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64): + @T.prim_func + def shuffle_elect_kernel(A: T.Tensor((num_threads,), T.int32)): + with T.Kernel(1, threads=num_threads) as _: + tx = T.get_thread_binding() + elected = T.shuffle_elect(thread_extent) + A[tx] = elected + + return shuffle_elect_kernel + + +def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None): + kernel = _get_laneid_kernel(num_threads, warp_size) + A = kernel() + print(kernel.get_kernel_source()) + print(A) + expected_warp_size = _resolve_warp_size(warp_size) + ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size + torch.testing.assert_close(A.cpu(), ref.cpu()) + return A + + +def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None): + kernel = _get_warp_idx_sync_kernel(num_threads, warp_size) + A = kernel() + print(kernel.get_kernel_source()) + print(A) + expected_warp_size = _resolve_warp_size(warp_size) + ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size + torch.testing.assert_close(A.cpu(), ref.cpu()) + return A + + +def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None): + kernel = _get_warp_idx_kernel(num_threads, warp_size) + A = kernel() + 
print(kernel.get_kernel_source()) + print(A) + expected_warp_size = _resolve_warp_size(warp_size) + ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size + torch.testing.assert_close(A.cpu(), ref.cpu()) + return A + + +def run_get_warp_group_idx( + num_threads: int = 128, + warp_size: Optional[int] = None, + warps_per_group: Optional[int] = None, +): + kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group) + A = kernel() + print(kernel.get_kernel_source()) + print(A) + expected_warp_size = _resolve_warp_size(warp_size) + expected_warps_per_group = _resolve_warps_per_group(warps_per_group) + threads_per_group = expected_warp_size * expected_warps_per_group + if threads_per_group <= 0: + raise ValueError("threads_per_group must be positive.") + ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group + torch.testing.assert_close(A.cpu(), ref.cpu()) + return A + + +def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64): + if thread_extent < 0: + raise ValueError("thread_extent must be non-negative.") + kernel = _shuffle_elect_kernel(num_threads, thread_extent) + A = kernel() + print(kernel.get_kernel_source()) + print(A) + indices = torch.arange(num_threads, device=A.device, dtype=torch.int64) + if thread_extent == 0: + mask = indices == 0 + elif thread_extent > 0: + mask = (indices % thread_extent) == 0 + else: + mask = torch.zeros_like(indices, dtype=torch.bool) + ref = mask.to(dtype=A.dtype, device=A.device) + torch.testing.assert_close(A.cpu(), ref.cpu()) + return A + + +@tilelang.testing.requires_cuda +def test_get_lane_idx_default(): + run_get_lane_id() + + +@tilelang.testing.requires_cuda +def test_get_lane_idx_custom(): + run_get_lane_id(num_threads=256, warp_size=64) + + +@tilelang.testing.requires_cuda +def test_get_warp_idx_sync_default(): + run_get_warp_idx_sync() + + +@tilelang.testing.requires_cuda +def test_get_warp_idx_sync_custom(): + run_get_warp_idx_sync(num_threads=256, warp_size=16) + + +@tilelang.testing.requires_cuda +def test_get_warp_idx_default(): + run_get_warp_idx() + + +@tilelang.testing.requires_cuda +def test_get_warp_idx_custom(): + run_get_warp_idx(num_threads=320, warp_size=20) + + +@tilelang.testing.requires_cuda +def test_get_warp_group_idx_default(): + run_get_warp_group_idx() + + +@tilelang.testing.requires_cuda +def test_get_warp_group_idx_custom(): + run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5) + + +@tilelang.testing.requires_cuda +def test_shuffle_elect_default(): + run_shuffle_elect(num_threads=256, thread_extent=64) + + +@tilelang.testing.requires_cuda +def test_shuffle_elect_block_leader(): + run_shuffle_elect(num_threads=128, thread_extent=0) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_if_range.py b/tilelang/original/testing/python/language/test_tilelang_language_if_range.py new file mode 100644 index 0000000000000000000000000000000000000000..c81a241ba12dd26eef6d5346da8441a582e6a3ab --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_if_range.py @@ -0,0 +1,53 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing + + +@tilelang.jit( + out_idx=[1], +) +def tilelang_if_range(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, 
block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + row_idx = by * block_M + i + col_idx = bx * block_N + j + # Test condition: ca < i < cb where ca=16, cb=96 + if 16 < row_idx < 96: + B[row_idx, col_idx] = A[row_idx, col_idx] * 2.0 + else: + B[row_idx, col_idx] = A[row_idx, col_idx] * 0.5 + + return main + + +def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype=T.float16): + kernel = tilelang_if_range(M, N, block_M, block_N, dtype) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + + # Reference computation + ref_b = torch.zeros_like(a) + for i in range(M): + for j in range(N): + # ca < i < cb where ca=16, cb=96 + if 16 < i < 96: + ref_b[i, j] = a[i, j] * 2.0 + else: + ref_b[i, j] = a[i, j] * 0.5 + + torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2) + + +def test_tilelang_if_range(): + run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_infinity.py b/tilelang/original/testing/python/language/test_tilelang_language_infinity.py new file mode 100644 index 0000000000000000000000000000000000000000..746afc4e0404cb495a1ee77b02aa13c005ca8c6f --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_infinity.py @@ -0,0 +1,32 @@ +import torch +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=-1) +def get_inf_kernel(dtype: str): + @T.prim_func + def main(A: T.Tensor((32,), dtype)): + with T.Kernel(1, threads=32): + T.fill(A, T.infinity(dtype)) + + return main + + +def _test_infinity(dtype: str): + kernel = get_inf_kernel(dtype) + output = kernel() + + assert torch.all(output == torch.inf), f"check failed for {dtype=}" + + +@tilelang.testing.requires_cuda +def test_infinity(): + _test_infinity(T.float16) + _test_infinity(T.bfloat16) + _test_infinity(T.float32) + _test_infinity(T.float64) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_int64.py b/tilelang/original/testing/python/language/test_tilelang_language_int64.py new file mode 100644 index 0000000000000000000000000000000000000000..d81e9dc6fab289d6dae760830136b86a2769dc78 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_int64.py @@ -0,0 +1,66 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def fill_symbolic(value: float, dtype=T.bfloat16): + n = T.symbolic("n", "int64") + block_n = 512 + + @T.prim_func + def main(x: T.Tensor[n, dtype]): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx: + # Doesn't yet work with int64-shaped global tensor + # T.fill(x[bx * block_n : (bx + 1) * block_n], value) + for i in T.Parallel(block_n): + x[bx * block_n + i] = value + + return main + + +def run_fill_symbolic(n: int): + import torch + + x = torch.zeros(n, dtype=torch.bfloat16, device="cuda") + fill_symbolic(1.0)(x) + assert x.min() == 1.0 and x.max() == 1.0 + + +def test_fill_symbolic(): + # Requires 8GB VRAM + run_fill_symbolic(2**32) + + +@tilelang.jit +def fill_static(n: int, value: float, dtype=T.bfloat16): + block_n = 512 + + @T.prim_func + def main(x: T.Tensor[n, dtype]): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx: + # Doesn't yet work with int64-shaped global tensor + # T.fill(x[bx * block_n : (bx + 1) * block_n], value) + 
for i in T.Parallel(block_n): + x[bx * block_n + i] = value + + return main + + +def run_fill_static(n: int): + import torch + + x = torch.zeros(n, dtype=torch.bfloat16, device="cuda") + fill_static(n, 1.0)(x) + assert x.min() == 1.0 and x.max() == 1.0 + + +def test_fill_static(): + # Requires 8GB VRAM + run_fill_static(2**32) + + +if __name__ == "__main__": + test_fill_symbolic() + test_fill_static() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_intrinsics_codegen.py b/tilelang/original/testing/python/language/test_tilelang_language_intrinsics_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d1e5401ee4216515e7875e491f48a933f3e50a --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_intrinsics_codegen.py @@ -0,0 +1,30 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +@tilelang.testing.requires_cuda +def test_language_ldg_codegen(): + N = 128 + + @T.prim_func + def main( + x: T.Tensor((N,), T.float32), + y: T.Tensor((N,), T.float32), + ): + with T.Kernel(N, threads=32) as pid: + # Explicitly request read-only cache load for x[pid] + y[pid] = T.__ldg(x[pid]) + 1.0 + + # Compile for CUDA and retrieve generated CUDA source + kernel = tilelang.compile(main, out_idx=[1], target="cuda") + src = kernel.get_kernel_source() + print(src) + # Assert that codegen uses __ldg on CUDA backend + # We look for the intrinsic call with address-of argument + assert "__ldg(" in src, "Expected __ldg call in generated CUDA source" + assert "__ldg(&" in src or "__ldg(&(" in src, "Expected address-of form in __ldg call" + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_lazy_jit.py b/tilelang/original/testing/python/language/test_tilelang_language_lazy_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..505730965463d1d7431286222db596743f058351 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_lazy_jit.py @@ -0,0 +1,418 @@ +from dataclasses import dataclass, field +import tilelang.testing +import tilelang +import tilelang.language as T +from typing import Any +from itertools import product +import torch + + +def _gemm_impl(): + @T.macro + def gemm_impl( + A: T.Tensor[[int, int], Any], + B: T.Tensor[[int, int], Any], + C: T.Tensor[[int, int], Any], + out_dtype: T.dtype, + block_M: int, + block_N: int, + block_K: int, + ): + dtype = A.dtype + M, K = A.shape + K, N = B.shape + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), out_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[bx * block_M, by * block_N]) + + return gemm_impl + + +def test_jit2_gemm_annot(): + @tilelang.lazy_jit + def gemm( + A: T.Tensor[[int, int], Any], + B: T.Tensor[[int, int], Any], + out_dtype: T.dtype = T.float32, + block_M: int = 64, + block_N: int = 64, + block_K: int = 32, + ): + M, K = A.shape + K, N = B.shape + C = T.empty(M, N, dtype=out_dtype) + _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K) + return C + + prod = product([T.float16, T.float32], [T.float32]) + gemm.par_compile( 
+ [ + {"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) + + for in_dtype, out_dtype in prod: + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + C_ref = out_dtype(A @ B) + C = gemm(A, B) + torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2) + + +def test_jit2_gemm_ptr(): + @tilelang.lazy_jit + def gemm_ptr( + A: T.ptr, + B: T.ptr, + C: T.ptr, + M: int, + N: int, + K: int, + dtype: T.dtype, + out_dtype: T.dtype, + block_M: int = 64, + block_N: int = 64, + block_K: int = 32, + ): + A = T.make_tensor(A, (M, K), dtype) + B = T.make_tensor(B, (K, N), dtype) + C = T.make_tensor(C, (M, N), out_dtype) + _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K) + + prod = product([T.float16, T.float32], [T.float32]) + gemm_ptr.par_compile( + [ + {"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) + for in_dtype, out_dtype in prod: + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + C_ref = out_dtype(A @ B) + C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda") + gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype) + torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2) + + +def test_jit2_annot(): + from tilelang.language.v2.annot import Annot, ArgVarTable + from tilelang.language.v2.builder import Builder + import traceback + + @dataclass + class AnnotTest: + annot: Annot + promote: Any + match_ok: list[Any] = field(default_factory=list) + match_ng: list[Any] = field(default_factory=list) + + tests = [ + AnnotTest( + annot=T.Tensor[[int, int], T.float32], + promote=False, + match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)], + match_ng=[ + torch.randn(1, 1, dtype=torch.float16), + T.Tensor(1, dtype=T.float32), + T.Tensor((1, 1), dtype=T.float16), + ], + ), + AnnotTest( + annot=T.Tensor[[int], Any], + promote=False, + match_ok=[ + torch.randn(12, dtype=torch.float32), + torch.randn(12, dtype=torch.float16), + T.Tensor((1,), dtype=T.float32), + T.Tensor((1,), dtype=T.float16), + ], + match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)], + ), + AnnotTest( + annot=T.Tensor[[int, 1], Any], + promote=False, + match_ok=[ + torch.randn(12, 1, dtype=torch.float32), + torch.randn(12, 1, dtype=torch.float16), + T.Tensor((12, 1), T.float32), + T.Tensor((12, 1), T.float16), + ], + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), + AnnotTest( + annot=T.Tensor[[T.dyn, 1], Any], + promote=False, + match_ok=[ + torch.randn(12, 1, dtype=torch.float32), + torch.randn(12, 1, dtype=torch.float16), + T.Tensor((12, 1), T.float32), + T.Tensor((12, 1), T.float16), + ], + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), + AnnotTest( + annot=T.Tensor[[1024, 1024], T.float32], + promote=True, + ), + AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]), + AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]), + ] + + for test in tests: + promote = test.annot.promote() + promoted = promote is not None + if promoted != test.promote: + raise 
AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}") + with Builder().prim_func("_test"): + for match_ok in test.match_ok: + try: + vt = ArgVarTable() + test.annot.create_prim_func_arg("arg", match_ok, vt) + except Exception as e: + traceback.print_exc() + raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e + for match_ng in test.match_ng: + try: + vt = ArgVarTable() + test.annot.create_prim_func_arg("arg", match_ng, vt) + raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}") + except Exception: + pass + + +def test_jit2_many_annot(): + @T.macro + def copy_impl(A, B): + M, N = A.shape + M_, N_ = B.shape + assert M == M_, f"M mismatch {M} {M_}" + assert N == N_, f"N mismatch {N} {N_}" + # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" + with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + + @tilelang.lazy_jit + def copy1( + A: T.Tensor[[int, int], T.float32], + B: T.Tensor[[int, int], T.float32], + ): + copy_impl(A, B) + + @tilelang.lazy_jit + def copy2( + A: T.Tensor[[128, 128], T.float32], + B: T.Tensor[[128, 128], T.float32], + ): + copy_impl(A, B) + + @tilelang.lazy_jit + def copy3( + A: T.Tensor[[int, 128], T.float32], + B: T.Tensor[[int, 128], T.float32], + ): + copy_impl(A, B) + + @tilelang.lazy_jit + def copy4( + A: T.Tensor[[T.dyn, int], T.float32], + B: T.Tensor[[T.dyn, int], T.float32], + ): + copy_impl(A, B) + + @tilelang.lazy_jit + def copy5( + A: T.StridedTensor[[int, int], [int, int], T.float32], + B: T.StridedTensor[[int, int], [int, int], T.float32], + ): + copy_impl(A, B) + + @tilelang.lazy_jit + def copy6( + A: T.StridedTensor[[T.dyn, int], [int, int], T.float32], + B: T.StridedTensor[[T.dyn, int], [int, int], T.float32], + ): + copy_impl(A, B) + + for copy in [copy1, copy2, copy3, copy4]: + A = torch.randn(128, 128, device="cuda") + B = torch.empty(128, 128, device="cuda") + copy(A, B) + assert torch.equal(B, A) + + for copy in [copy5, copy6]: + A = torch.randn(128, 2, 128, 2, device="cuda") + B = torch.randn(128, 2, 128, 2, device="cuda") + copy(A[:, 0, :, 0], B[:, 0, :, 0]) + assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0]) + + +def test_jit2_return(): + @T.macro + def copy_impl(A): + M, N = A.shape + B = T.empty(M, N, dtype=A.dtype) + M, N = A.shape + M_, N_ = B.shape + assert M == M_, f"M mismatch {M} {M_}" + assert N == N_, f"N mismatch {N} {N_}" + # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" + with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + return B + + @tilelang.lazy_jit + def copy0(A: T.Tensor[[int, int], Any]): + return copy_impl(A) + + @tilelang.lazy_jit + def copy1( + A: T.Tensor[[int, int], T.float32], + ): + return copy_impl(A) + + @tilelang.lazy_jit + def copy2( + A: T.Tensor[[128, 128], T.float32], + ): + return copy_impl(A) + + @tilelang.lazy_jit + def copy3( + A: T.Tensor[[int, 128], T.float32], + ): + return copy_impl(A) + + @tilelang.lazy_jit + def copy4( + A: T.Tensor[[T.dyn, int], T.float32], + ): + return copy_impl(A) + + @tilelang.lazy_jit + def copy5( + A: T.StridedTensor[[int, int], [int, int], T.float32], + ): + return copy_impl(A) + + 
@tilelang.lazy_jit + def copy6( + A: T.StridedTensor[[T.dyn, int], [int, int], T.float32], + ): + return copy_impl(A) + + for copy in [copy0, copy1, copy2, copy3, copy4]: + A = torch.randn(128, 128, device="cuda") + B = copy(A) + assert torch.equal(B, A) + for copy in [copy5, copy6]: + A = torch.randn(128, 2, 128, 2, device="cuda") + B = copy(A[:, 0, :, 0]) + assert torch.equal(A[:, 0, :, 0], B) + + +def test_jit2_deepseek_deepgemm(): + @tilelang.lazy_jit + def deep_gemm( + A: T.Tensor[[int, int], T.float8_e4m3fn], + B: T.Tensor[[int, int], T.float8_e4m3fn], + scales_a: T.Tensor[[int, int], T.float32], + scales_b: T.Tensor[[int, int], T.float32], + out_dtype: T.dtype = T.bfloat16, + accum_dtype: T.dtype = T.float32, + block_N: int = 128, + block_M: int = 128, + block_K: int = 128, + ): + # A: [M, K] + # B: [N, K] + # scales_a: [M, K // 128] + # scales_b: [N, K // 128] + # C: [M, N] + + group_size = 128 + in_dtype = A.dtype + M, K = A.shape + N, K = B.shape + C = T.empty(M, N, dtype=out_dtype) + + assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" + assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}" + assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}" + + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), in_dtype) + B_shared = T.alloc_shared((block_N, block_K), in_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + scale_C_shared = T.alloc_shared((block_M,), T.float32) + C_local = T.alloc_fragment((block_M, block_K), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * scale_C_shared[i] + T.clear(C_local) + + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return C + + +# def ceildiv(a, b): +# return (a + b - 1) // b + +# def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): +# # A_scale: (M, K//128) ==> (M//128, K//128, 128) +# # B_scale: (N//128, K//128) ==> (N//128, K//128, 128) +# # A_fp8: (M, K) +# # B_fp8: (N, K) +# # out_dtype: float16 or float32 +# # return C: (M, N) +# M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1] +# A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1) +# B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128) +# C = torch.zeros(M, N, device="cuda", dtype=out_dtype) +# c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32) +# for i in range(ceildiv(M, 128)): +# for j in range(ceildiv(N, 128)): +# c_acc.zero_() +# for k in range(ceildiv(K, 128)): +# c = torch._scaled_mm( +# A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], +# B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, +# scale_a=A_scales[i, k].view(128, 1).contiguous(), +# scale_b=B_scales[j, k].view(1, 128).contiguous(), +# 
out_dtype=torch.bfloat16) +# c_acc += c.to(torch.float32) +# C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) +# return C + +# M, N, K = 1024, 1024, 8192 +# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, ) + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_let.py b/tilelang/original/testing/python/language/test_tilelang_language_let.py new file mode 100644 index 0000000000000000000000000000000000000000..6f94ad66493769928be0526d2a3f60b08057ebae --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_let.py @@ -0,0 +1,22 @@ +import tilelang.testing +from tilelang import tvm as tvm +from tilelang import language as T + + +def test_let_vectorize_load(): + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + b = A[0, 0:4] + A[0, 4:8] = b + + mod = tvm.IRModule({"main": main}) + mod = tvm.compile(mod, target="cuda") + assert "float4 b" in mod.mod.imports[0].inspect_source() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_let_layout.py b/tilelang/original/testing/python/language/test_tilelang_language_let_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..fec30b914b2f8075c18fbd305382802b37686dd3 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_let_layout.py @@ -0,0 +1,123 @@ +""" +Test layout inference for LetStmt expressions. + +This test validates that TileLang correctly handles layout inference when +fragment buffer accesses occur through let bindings. For example: + + block_mask_f = T.alloc_fragment((N_S,), T.int32) + T.copy(BlockMask[by, :], block_mask_f) + for i in T.Pipelined(N_S): + a = block_mask_f[i] # LetStmt: a is bound to fragment buffer load + T.copy(A[a, 0], A_shared) # a is used as index in TMA copy + +Key scenarios tested: +1. Fragment buffer layout inference through let bindings +2. TMA (Tensor Memory Accelerator) copy with let-bound indices +3. CP.ASYNC copy with let-bound indices +4. 
Warp specialization with let-bound fragment accesses +""" + +import tilelang +import tilelang.language as T +import tilelang.testing +import torch + + +def blocksparse_copy_kernel(M, N, N_S, block_M, block_N, dtype=T.float16): + """BlockSparse copy kernel using fragment for block mask indices.""" + block_mask_shape = (M // block_M, N_S) + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) + block_mask_f = T.alloc_fragment((N_S,), T.int32) + + T.clear(B_shared) + T.copy(BlockMask[by, :], block_mask_f) + for i in T.Pipelined(N_S): + a = block_mask_f[i] # LetStmt: fragment buffer access + if a >= 0: + T.copy(A[a, 0], A_shared) + T.copy(A_shared, B[by * block_M : (by + 1) * block_M, i * block_N : (i + 1) * block_N]) + + return main + + +def ref_blocksparse_copy(A, B, BlockMask, M, N, N_S, block_M, block_N): + """Reference implementation for blocksparse copy.""" + ref_B = B.clone() + num_row_blocks = M // block_M + + for by in range(num_row_blocks): + for i in range(N_S): + src_row_start = BlockMask[by, i].item() + ref_B[by * block_M : (by + 1) * block_M, i * block_N : (i + 1) * block_N] = A[ + src_row_start : src_row_start + block_M, 0:block_N + ] + + return ref_B + + +def run_blocksparse_copy(M, N, block_M, block_N, pass_configs=None): + """Run blocksparse copy test with given parameters.""" + N_S = N // block_N + + program = blocksparse_copy_kernel(M, N, N_S, block_M, block_N) + kernel = tilelang.compile( + program, + out_idx=[1], + target="cuda", + pass_configs=pass_configs or {}, + ) + + # Initialize tensors + a = torch.randn(M, N, device="cuda", dtype=torch.float16) + b = torch.zeros(M, N, device="cuda", dtype=torch.float16) + + # Create BlockMask with valid row indices + num_row_blocks = M // block_M + block_mask = torch.zeros((num_row_blocks, N_S), dtype=torch.int32, device="cuda") + for by in range(num_row_blocks): + for i in range(N_S): + max_row_block = (M - block_M) // block_M + block_mask[by, i] = torch.randint(0, max_row_block + 1, (1,)).item() * block_M + + # Run kernel + c = kernel(a, block_mask) + + # Compute reference + ref_c = ref_blocksparse_copy(a, b, block_mask, M, N, N_S, block_M, block_N) + + # Verify + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +def test_blocksparse_copy_tma(): + """Test blocksparse copy with TMA (Tensor Memory Accelerator).""" + run_blocksparse_copy(M=1024, N=1024, block_M=128, block_N=128, pass_configs={}) + + +@tilelang.testing.requires_cuda +def test_blocksparse_copy_cp_async(): + """Test blocksparse copy with CP.ASYNC (without TMA).""" + run_blocksparse_copy( + M=1024, + N=1024, + block_M=128, + block_N=128, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_mask_op.py b/tilelang/original/testing/python/language/test_tilelang_language_mask_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8f899729133357cd7cfe322a122ffeef079ac2ee --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_mask_op.py @@ -0,0 +1,153 @@ +import tilelang +import 
tilelang.language as T +import torch + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + tx = T.get_thread_binding(0) + + if tx < 128: + for i, k in T.Parallel(block_M, block_N): + A_shared[i, k] = A[by * block_M + i, bx * block_N + k] + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return main + + +def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_mask_parallel(): + run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128) + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + tx = T.get_thread_binding(0) + + if tx < 128: + T.copy(A[by * block_M, bx * block_N], A_shared) + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return main + + +def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_mask_copy(): + run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128) + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + tx = T.get_thread_binding(0) + + if tx >= 128 and tx < 256: + for i, k in T.Parallel(block_M, block_N): + A_shared[i, k] = A[by * block_M + i, bx * block_N + k] + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return main + + +def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + 
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_mask_parallel_range(): + run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128) + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + tx = T.get_thread_binding(0) + + if tx >= 128 and tx < 256: + T.copy(A[by * block_M, bx * block_N], A_shared) + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return main + + +def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) + + +def test_tilelang_copy_mask_copy_range(): + run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128) + + +if __name__ == "__main__": + test_tilelang_copy_mask_copy_range() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_negative_index.py b/tilelang/original/testing/python/language/test_tilelang_language_negative_index.py new file mode 100644 index 0000000000000000000000000000000000000000..feeed2c6fdd89573686cae588d20bfbc0490dd7b --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_negative_index.py @@ -0,0 +1,59 @@ +from tilelang import tvm +import tilelang as tl +import tilelang.testing +import tilelang.language as T + + +@T.prim_func +def negative_index_before(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): + T.func_attr({"tir.noalias": True}) + B[0] = A[T.int32(-1)] + + +@T.prim_func +def negative_index_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): + T.func_attr({"tir.noalias": True}) + B[0] = A[T.int32(15)] + + +@T.prim_func +def negative_index_loop_before(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)): + T.func_attr({"tir.noalias": True}) + for i in T.serial(4): + B[i] = A[-i - 1] + + +@T.prim_func +def negative_index_loop_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)): + T.func_attr({"tir.noalias": True}) + for i in T.serial(4): + B[i] = A[15 - i] + + +@T.prim_func +def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): + T.func_attr({"tir.noalias": True}) + for i in T.serial(16): + B[i] = A[shift + i] + + +def test_legalize_negative_index_scalar(): + mod = tvm.IRModule({"main": negative_index_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body) + + +def test_legalize_negative_index_affine_expr(): + mod = tvm.IRModule({"main": negative_index_loop_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body) + + +def test_legalize_negative_index_symbolic_passthrough(): + mod = tvm.IRModule({"main": 
negative_index_symbolic_before}) + transformed = tl.transform.LegalizeNegativeIndex()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_parallel.py b/tilelang/original/testing/python/language/test_tilelang_language_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..a392e70b687ce892ba4455dd5385f9b8f07f06f0 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_parallel.py @@ -0,0 +1,70 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import pytest + +tilelang.testing.set_random_seed() + + +@tilelang.jit(out_idx=[1]) +def parallel_elementwise_static(length=256, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length): + B[i] = A[i] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def parallel_elementwise_dynamic(max_len=512, threads=256, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((max_len,), dtype), + B: T.Tensor((max_len,), dtype), + valid_len: T.int32, + ): + with T.Kernel(1, threads=threads) as _: + for i in T.Parallel(max_len): + B[i] = 0.0 + span = T.min(valid_len, max_len) + for i in T.Parallel(span): + B[i] = A[i] - 1.0 + + return main + + +def _require_cuda_tensor(shape, dtype=torch.float32): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randn(*shape, device="cuda", dtype=dtype) + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +def test_parallel_static_extent(): + kernel = parallel_elementwise_static(length=256) + data = _require_cuda_tensor((256,), torch.float32) + result = kernel(data) + torch.testing.assert_close(result, data + 1.0, atol=1e-5, rtol=1e-5) + + +def test_parallel_dynamic_extent(): + kernel = parallel_elementwise_dynamic(max_len=512, threads=256) + data = _require_cuda_tensor((512,), torch.float32) + for valid_len in [0, 13, 200, 600]: + out = kernel(data, valid_len) + reference = torch.zeros_like(data) + clip = min(valid_len, data.shape[0]) + reference[:clip] = data[:clip] - 1.0 + torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_pipeline.py b/tilelang/original/testing/python/language/test_tilelang_language_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8136e246f0e48b2697a13376f773172333ce83cb --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_pipeline.py @@ -0,0 +1,213 @@ +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + trans_A = False + trans_B = False + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 + num_threads = 128 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_pipeline_order_stage(): + run_gemm(order=[0, 1, 2], stage=[0, 0, 1]) + run_gemm(order=[0, 1, 2], stage=[0, 0, 2]) + run_gemm(order=[1, 2, 0], stage=[0, 0, 2]) + run_gemm(order=[1, 2, 0], stage=[0, 0, 1]) + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype=T.float16, accum_dtype=T.float32): + block_mask_shape = (M // block_M, N // block_N, K // block_K) + + import tilelang.language as T + + @T.prim_func + def block_sparse_matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + block_mask = T.alloc_local((1,), "bool") + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + block_mask[0] = BlockMask[by, bx, k] + if block_mask[0]: + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return block_sparse_matmul + + +def run_blocksparse_matmul(num_stages): + import torch + + M = 256 + N = 256 + K = 256 + block_M = 128 + block_N = 128 + block_K = 32 + sparsity = 0.5 + + # Initialize input 
matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + kernel = blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages) + print(kernel.get_kernel_source()) + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + def ref_program(A, B, BlockMask, block_M, block_N, block_K): + ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + for i in range(M // block_M): + for j in range(N // block_N): + accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) + for k in range(K // block_K): + if BlockMask[i, j, k]: + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) + return ref_c + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def test_blocksparse_matmul(): + run_blocksparse_matmul(num_stages=1) + run_blocksparse_matmul(num_stages=2) + run_blocksparse_matmul(num_stages=3) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_ptr.py b/tilelang/original/testing/python/language/test_tilelang_language_ptr.py new file mode 100644 index 0000000000000000000000000000000000000000..85458139a5529deda0e0b341516bcfea8559b96d --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_ptr.py @@ -0,0 +1,66 @@ +import torch +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl +import tilelang.language as T +from tilelang.utils import map_torch_type + + +def matmul_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + a_ptr: T.ptr, + b_ptr: T.ptr, + c_ptr: T.ptr, + m: T.int32, + n: T.int32, + k: T.int32, + ): + A = T.make_tensor(a_ptr, (m, k), dtype) + B = T.make_tensor(b_ptr, (k, n), dtype) + C = T.make_tensor(c_ptr, (m, n), accum_dtype) + + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3): + # Copy tile of A + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) + jit_kernel = tl.compile(program, target="cuda", execution_backend="cython") + + def ref_program(a, b): + return (a @ b.T).to(torch.float32) + + a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) + b = torch.randn(N, K, device="cuda", 
dtype=map_torch_type(dtype)) + + c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) + + jit_kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K) + + ref_c = (a @ b.T).to(map_torch_type(accum_dtype)) + + torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2) + + +def test_matmul(): + run_matmul(1024, 1024, 1024, 128, 128, 32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_rand.py b/tilelang/original/testing/python/language/test_tilelang_language_rand.py new file mode 100644 index 0000000000000000000000000000000000000000..daf51dbb7f84c575a5a1ffc9d12ecf954b9ab311 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_rand.py @@ -0,0 +1,37 @@ +import tilelang +import tilelang.language as T +import torch +import pytest +import tilelang.testing + + +@tilelang.jit +def tilelang_rand_1d(M=1024, seed=42): + num_per_thread = 128 + threads = 1 + blk_M = num_per_thread * threads + + @T.prim_func + def rand_kernel(A: T.Tensor((M,), "uint32")): + with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx: + tx = T.get_thread_binding() + T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread) + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + A[idx] = T.rng_rand() + + return rand_kernel + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize("M, seed", [(1024, 42), (512, 123), (128, 0)]) +def test_rand_1d(M, seed): + kernel = tilelang_rand_1d(M, seed) + tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda") + kernel(tilelang_result) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_reduce.py b/tilelang/original/testing/python/language/test_tilelang_language_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9bf61303d356a215ba7d42245a00c705de3504 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_reduce.py @@ -0,0 +1,223 @@ +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl +import tilelang.language as T + +tilelang.testing.set_random_seed() + + +def _make_shared_reduce(M, N, dtype, reduce_cb): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1) as _: + A_shared = T.alloc_shared((M, N), dtype) + B_shared = T.alloc_shared((M,), dtype) + + T.copy(A, A_shared) + reduce_cb(T, A_shared, B_shared) + T.copy(B_shared, B) + + return main + + +def _run_program(program, ref_program, atol=1e-2, rtol=1e-2): + jit_kernel = tl.compile(program, out_idx=-1) + profiler = jit_kernel.get_profiler() + profiler.assert_allclose(ref_program, atol=atol, rtol=rtol) + + +def reduce_max_test(M, N, dtype=T.float16): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.reduce_max(A_local, B_local, dim=1) + T.copy(B_local, B) + + return main + + +def reduce_sum_test(M, N, dtype=T.float32): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, 
A_local) + T.reduce_sum(A_local, B_local, dim=1) + T.copy(B_local, B) + + return main + + +def reduce_sum_ss(M, N, dtype=T.float32): + return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1)) + + +def reduce_max_ss(M, N, dtype=T.float32): + return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1)) + + +def reduce_min_ss(M, N, dtype=T.float32): + return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1)) + + +def reduce_abssum_ss(M, N, dtype=T.float32): + return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1)) + + +def reduce_absmax_ss(M, N, dtype=T.float32): + return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1)) + + +def run_reduce_sum(M, N, dtype=T.float32, mode="rr"): + if mode == "rr": + program = reduce_sum_test(M, N, dtype) + elif mode == "ss": + program = reduce_sum_ss(M, N, dtype) + else: + raise NotImplementedError("run_reduce_sum only supports rr and ss") + _run_program(program, lambda A: A.sum(dim=1)) + + +def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32): + program = program_builder(M, N, dtype) + _run_program(program, ref_program) + + +def run_reduce_max(M, N, dtype=T.float16): + program = reduce_max_test(M, N, dtype) + _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2) + + +def test_reduce_sum(): + run_reduce_sum(256, 256) + run_reduce_sum(512, 128) + run_reduce_sum(128, 512) + + +def test_reduce_sum_shared(): + run_reduce_sum(64, 64, mode="ss") + + +def test_reduce_max(): + run_reduce_max(256, 256, T.float16) + run_reduce_max(512, 128, T.float16) + run_reduce_max(256, 256, T.float32) + + +def test_reduce_max_shared(): + run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, T.float32) + + +def test_reduce_min_shared(): + run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, T.float32) + + +def test_reduce_abssum_shared(): + run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, T.float32) + + +def test_reduce_absmax_shared(): + run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, T.float32) + + +def reduce_sum_test_clear(M, N, dtype=T.float32): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.fill(B_local, 1) + T.reduce_sum(A_local, B_local, dim=1, clear=False) + T.copy(B_local, B) + + return main + + +def run_reduce_sum_clear(M, N, dtype=T.float32): + program = reduce_sum_test_clear(M, N, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + + def ref_program(A): + return A.sum(dim=1) + 1 + + import torch + + dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda() + ref_out = ref_program(dummy_A) + tl_out = jit_kernel(dummy_A) + torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2) + + +def test_reduce_sum_clear(): + run_reduce_sum_clear(256, 256, T.float32) + run_reduce_sum_clear(512, 128, T.float32) + run_reduce_sum_clear(128, 512, T.float32) + + +def reduce_max_test_clear(M, N, dtype=T.float16): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = 
T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.fill(B_local, -T.infinity(dtype)) + T.reduce_max(A_local, B_local, dim=1, clear=False) + T.copy(B_local, B) + + return main + + +def run_reduce_max_clear(M, N, dtype=T.float16): + program = reduce_max_test_clear(M, N, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + + def ref_program(A): + return A.max(dim=1).values + + import torch + + dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda() + ref_out = ref_program(dummy_A) + tl_out = jit_kernel(dummy_A) + torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2) + + +def test_reduce_max_clear(): + run_reduce_max_clear(256, 256, T.float16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_reshape.py b/tilelang/original/testing/python/language/test_tilelang_language_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..10c3d0ce87b76b85e18946fa2569f7a12dec3433 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_reshape.py @@ -0,0 +1,282 @@ +import tilelang.testing +import tilelang as tl +from tilelang import language as T +import torch +import pytest + + +def reshape_test(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), + ): + with T.Kernel(1) as _: + A_reshaped = T.reshape(A, [N // M, M]) + T.copy(A_reshaped, B) + + return main + + +def run_reshape(N, M, dtype): + program = reshape_test(N, M, dtype) + # TODO(lei): reshape cannot apply shared memory + # layout transform propagation + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.reshape(N // M, M) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_smem(): + # Test reshape + run_reshape(1024, 32, T.float32) + run_reshape(2048, 64, T.float16) + + +def reshape_test_smem_1d_2_2d(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), + ): + with T.Kernel(1) as _: + A_shared = T.alloc_shared((N,), dtype) + for i in T.Parallel(N): + A_shared[i] = A[i] + + A_smem_reshaped = T.reshape(A_shared, [N // M, M]) + T.copy(A_smem_reshaped, B) + + return main + + +def run_reshape_smem_1d_2_2d(N, M, dtype): + program = reshape_test_smem_1d_2_2d(N, M, dtype) + # TODO(lei): reshape cannot apply shared memory + # layout transform propagation + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.reshape(N // M, M) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_smem_1d_2_2d(): + run_reshape_smem_1d_2_2d(1024, 32, T.float32) + run_reshape_smem_1d_2_2d(2048, 64, T.float16) + + +def reshape_test_smem_2d_2_1d(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(1) as _: + A_shared = T.alloc_shared((N // M, M), dtype) + for i, j in T.Parallel(N // M, M): + A_shared[i, j] = A[i, j] + + A_smem_reshaped = T.reshape(A_shared, [N]) + T.copy(A_smem_reshaped, B) + + return main + + +def run_reshape_smem_2d_2_1d(N, M, dtype): + program = 
reshape_test_smem_2d_2_1d(N, M, dtype) + # TODO(lei): reshape cannot apply shared memory + # layout transform propagation + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.reshape(N) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_smem_2d_2_1d(): + run_reshape_smem_2d_2_1d(1024, 32, T.float32) + run_reshape_smem_2d_2_1d(2048, 64, T.float16) + + +def reshape_fragment_test(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") + A_local = T.alloc_fragment((N // M, M), dtype) + B_shared = T.alloc_shared((N,), dtype, scope="shared") + + T.copy(A, A_shared) + T.copy(A_shared, A_local) + A_local_reshape = T.reshape(A_local, [N]) + T.copy(A_local_reshape, B_shared) + T.copy(B_shared, B) + + return main + + +def run_reshape_fragment(N, M, dtype): + program = reshape_fragment_test(N, M, dtype) + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.reshape(N) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_fragment(): + run_reshape_fragment(1024, 32, T.float32) + run_reshape_fragment(2048, 64, T.float16) + + +def reshape_layout_transform_shared(N, M, dtype): + from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout + + @T.prim_func + def main( + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") + + T.annotate_layout( + { + A_shared: make_mma_swizzle_layout(A_shared), + } + ) + T.copy(A, A_shared) + A_shared_reshape = T.reshape(A_shared, [N]) + T.copy(A_shared_reshape, B) + + return main + + +def run_reshape_layout_transform_shared(N, M, dtype): + program = reshape_layout_transform_shared(N, M, dtype) + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.reshape(N) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_layout_transform_shared(): + run_reshape_layout_transform_shared(1024, 32, T.float32) + run_reshape_layout_transform_shared(2048, 64, T.float16) + + +def reduce_after_reshape_test(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_shared = T.alloc_shared((N,), dtype, scope="shared") + A_local = T.alloc_fragment((N,), dtype) + B_local = T.alloc_fragment((N // M,), dtype) + + T.copy(A, A_shared) + T.copy(A_shared, A_local) + A_local_reshape = T.reshape(A_local, [N // M, M]) + T.reduce_max(A_local_reshape, B_local, dim=1) + T.copy(B_local, B) + + return main + + +def run_reduce_after_reshape(N, M, dtype): + program = reduce_after_reshape_test(N, M, dtype) + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return torch.max(A.reshape(N // M, M), dim=1).values + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reduce_after_reshape(): + run_reduce_after_reshape(1024, 32, T.float32) + run_reduce_after_reshape(2048, 64, T.float16) + + +def reshape_shape_mismatch_test(N, M, dtype): + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), + ): + with T.Kernel(1) as _: + A_reshaped = T.reshape(A, [N // M, M + 1]) + T.copy(A_reshaped, B) + + return main + + +def test_reshape_shape_mismatch(): + with pytest.raises(AssertionError): + reshape_shape_mismatch_test(1024, 32, T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_ternary.py b/tilelang/original/testing/python/language/test_tilelang_language_ternary.py new file mode 100644 index 0000000000000000000000000000000000000000..20c7b5e778ce8f0bb76aa1250fa3e8861b321619 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_ternary.py @@ -0,0 +1,44 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing + + +@tilelang.jit( + out_idx=[1], +) +def tilelang_ternary(M, N, block_M, block_N, dtype=T.float16): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0 + + return main + + +def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype=T.float16): + kernel = tilelang_ternary(M, N, block_M, block_N, dtype) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + ref_b = torch.zeros_like(b) + for i in range(M): + for j in range(N): + if i < M // 2: + ref_b[i, j] = a[i, j] + else: + ref_b[i, j] = 0 + + torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2) + + +def test_tilelang_ternary(): + run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_tma_1d.py b/tilelang/original/testing/python/language/test_tilelang_language_tma_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb79c10c65c893c52e750c8458bb5397dd62e68 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_tma_1d.py @@ -0,0 +1,56 @@ +import torch +import tilelang +import tilelang.language as T + + +def ref_program(x, y): + return x + y + + +@tilelang.jit(out_idx=[-1]) +def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): + @T.prim_func + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), in_dtype) + B_shared = T.alloc_shared((block_M, block_N), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + for local_y, 
local_x in T.Parallel(block_M, block_N): + C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return elem_add + + +def run_elementwise_add(M, N): + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") + + # Default config + block_M, block_N = 128, 128 + config = {"block_M": block_M, "block_N": block_N, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) + + out = kernel(a, b) + torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) + + code = kernel.get_kernel_source() + if block_N == N: + assert "tma_load" in code and "CUtensorMap" not in code + else: + assert "tma_load" in code and "CUtensorMap" in code + + +def main(): + run_elementwise_add(128, 128) + run_elementwise_add(256, 128) + run_elementwise_add(256, 256) + + +if __name__ == "__main__": + main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_unroll.py b/tilelang/original/testing/python/language/test_tilelang_language_unroll.py new file mode 100644 index 0000000000000000000000000000000000000000..06367e975e59c82e03c1c5309a4707ca5f62602f --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_unroll.py @@ -0,0 +1,35 @@ +import tilelang.testing +from tilelang import tvm as tvm +from tilelang import language as T + + +def test_unroll_with_step(): + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + for i in T.unroll(0, 16, step=4): + A[0, i] = 1.0 + + kernel = tilelang.compile(main, target="cuda") + assert "#pragma unroll" in kernel.get_kernel_source() + + +def test_unroll_with_unroll_factor(): + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + for i in T.unroll(0, 16, unroll_factor=4): + A[0, i] = 1.0 + + kernel = tilelang.compile(main, target="cuda") + assert "#pragma unroll 4" in kernel.get_kernel_source() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_var_init.py b/tilelang/original/testing/python/language/test_tilelang_language_var_init.py new file mode 100644 index 0000000000000000000000000000000000000000..36d9bf01419205756571978b331df9f62833a668 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_var_init.py @@ -0,0 +1,30 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +def test_var_assign() -> None: + @tilelang.jit(out_idx=-1) + def jit_kernel(): + @T.prim_func + def test_var_assign(A: T.Tensor((2,), T.int32)): + with T.Kernel(1) as _: + a = T.alloc_var(T.int32, init=1) + b = T.alloc_var(T.int32, init=a) # b gets value of a + a = 2 + d = T.alloc_var(T.int32, init=a) # c gets new value of a + A[0] = b + A[1] = d + + print(test_var_assign) + return test_var_assign + + kernel = jit_kernel() + print(kernel.get_kernel_source()) + res = kernel() + assert res[0] == 1 + assert res[1] == 2 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/tilelang/original/testing/python/language/test_tilelang_language_vectorize.py b/tilelang/original/testing/python/language/test_tilelang_language_vectorize.py new file mode 100644 index 0000000000000000000000000000000000000000..75360bb19c48fe90b42ec35efd4057da8e193a81 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_vectorize.py @@ -0,0 +1,116 @@ +import torch +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) +def vectorize_test(N, M, stride_A, stride_B): + @T.prim_func + def main( + A: T.StridedTensor[(N, M), (1, stride_A), T.float32], # noqa: F821 + B: T.StridedTensor[(N, M), (1, stride_B), T.float32], # noqa: F821 + ): + with T.Kernel(M // 128, threads=128) as (bx): + tx = T.get_thread_binding(0) + col = bx * 128 + tx + + for row in T.vectorized(N): + B[row, col] = A[row, col] + + return main + + +def run_vectorize(N, M, stride_A, stride_B): + assert N % 128 == 0 and M % 128 == 0 + assert stride_A >= N and stride_B >= N + + jit_kernel = vectorize_test(N, M, stride_A, stride_B) + + base_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32) + base_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32) + a = torch.as_strided(base_a, size=(N, M), stride=(1, stride_A)) + b = torch.as_strided(base_b, size=(N, M), stride=(1, stride_B)) + + jit_kernel(a, b) + + torch.testing.assert_close(a, b, atol=1e-8, rtol=1e-8) + + code = jit_kernel.get_kernel_source() + + vectorize_size = 1 + while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0: + vectorize_size *= 2 + + if vectorize_size == 4: + assert "float4" in code + elif vectorize_size == 2: + assert "float2" in code + + +def test_vectorize(): + N, M = 512, 256 + + run_vectorize(N, M, N, N) + run_vectorize(N, M, N + 2, N + 4) + run_vectorize(N, M, N + 4, N + 8) + run_vectorize(N, M, N + 8, N + 16) + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) +def vectorize_test_invariant_index(N, M, K): + @T.prim_func + def main( + A: T.Tensor[(N, M), T.float32], # noqa: F821 + B: T.Tensor[(N, M), T.float32], # noqa: F821 + C: T.Tensor[(N, M // K), T.float32], # noqa: F821 + ): + with T.Kernel(N // 128, threads=128) as (bx): + tx = T.get_thread_binding(0) + row = bx * 128 + tx + + for col in T.vectorized(M): + B[row, col] = A[row, col] * C[row, col // K] + + return main + + +def run_vectorize_invariant_index(N, M, K): + assert N % 128 == 0 and M % K == 0 + + jit_kernel = vectorize_test_invariant_index(N, M, K) + + a = torch.randn(N, M, device="cuda", dtype=torch.float32) + b = torch.zeros(N, M, device="cuda", dtype=torch.float32) + c = torch.randn(N, M // K, device="cuda", dtype=torch.float32) + + jit_kernel(a, b, c) + + indices = torch.arange(a.size(1)) // K + ret = a * c[:, indices] + torch.testing.assert_close(b, ret, atol=1e-8, rtol=1e-8) + + code = jit_kernel.get_kernel_source() + + vectorize_size = 1 + while vectorize_size <= 2 and K % (vectorize_size * 2) == 0: + vectorize_size *= 2 + + if vectorize_size == 4: + assert "float4" in code + elif vectorize_size == 2: + assert "float2" in code + + +def test_vectorize_invariant_index(): + N, M = 512, 256 + + run_vectorize_invariant_index(N, M, 2) + run_vectorize_invariant_index(N, M, 4) + run_vectorize_invariant_index(N, M * 3, 6) + run_vectorize_invariant_index(N, M, 8) + run_vectorize_invariant_index(N, M * 3, 12) + run_vectorize_invariant_index(N, M * 7, 14) + + +if 
__name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_vectorized_cast.py b/tilelang/original/testing/python/language/test_tilelang_language_vectorized_cast.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0a0942a25068b69596393741305fd441b1935a --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -0,0 +1,107 @@ +import pytest +import torch +import tilelang.testing +import tilelang.language as T + +str2dtype = { + T.float32: torch.float32, + T.float16: torch.float16, + T.bfloat16: torch.bfloat16, + T.float8_e4m3fn: torch.float8_e4m3fn, + T.float8_e5m2: torch.float8_e5m2, +} + + +@tilelang.jit +def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): + assert M % 256 == 0 + + @T.prim_func + def main( + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 + ): + with T.Kernel(1, threads=128): + T.copy(A, B) + + return main + + +@tilelang.jit +def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): + assert M % 256 == 0 + + @T.prim_func + def main( + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 + ): + with T.Kernel(1, threads=128): + A_local = T.alloc_fragment((M,), dtype_A) + B_local = T.alloc_fragment((M,), dtype_B) + + T.copy(A, A_local) + for i in T.Parallel(M): + B_local[i] = A_local[i] + T.copy(B_local, B) + + return main + + +def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2): + """Run the vectorized cast kernel and check the correctness. + Args: + src_dtype_str: The source data type string. + dst_dtype_str: The destination data type string. + check_str: Used to ensure vectorized cast is used. + lanes: The number of lanes of the source and destination data types. + """ + + M = 128 * lanes + kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str) + kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str) + + A_float = torch.randn(M, dtype=torch.float32, device="cuda") + A = A_float.to(str2dtype[src_dtype_str]) + B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda") + C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda") + + kernel(A, B) + kernel_parallel(A, C) + + torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B) + torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C) + + code = kernel.get_kernel_source() + code_parallel = kernel_parallel.get_kernel_source() + + assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" 
+ + +@pytest.mark.parametrize( + "src_dtype, dst_dtype, check_str, lanes", + [ + (T.float32, T.float16, "__float22half2_rn", 2), + (T.float32, T.float16, "__float22half2_rn", 4), + (T.float16, T.float32, "__half22float2", 2), + (T.float16, T.float32, "__half22float2", 4), + (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2), + (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4), + (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2), + (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4), + (T.float32, T.bfloat16, "__float22bfloat162_rn", 2), + (T.float32, T.bfloat16, "__float22bfloat162_rn", 4), + (T.bfloat16, T.float32, "__bfloat1622float2", 2), + (T.bfloat16, T.float32, "__bfloat1622float2", 4), + (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2), + (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), + (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + ], +) +def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): + run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_view.py b/tilelang/original/testing/python/language/test_tilelang_language_view.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4c3711b26aa3f180089027ab55f3ad496339e5 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_view.py @@ -0,0 +1,86 @@ +import tilelang.language as T +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl +import pytest + + +def view_test(N, M, dtype, new_dtype=None): + new_shape = [N // M, M] + if new_dtype: + from tvm import DataType + + dtype_src = DataType(dtype) + dtype_dst = DataType(new_dtype) + src_bits = dtype_src.bits + dst_bits = dtype_dst.bits + scale = src_bits / dst_bits + new_shape[-1] = int(M * scale) + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + ): + with T.Kernel(1) as _: + A_viewed = T.view(A, new_shape, dtype=new_dtype) + T.copy(A_viewed, B) + + return main + + +def run_view(N, M, dtype, new_dtype=None): + program = view_test(N, M, dtype, new_dtype) + jit_kernel = tl.compile(program, out_idx=-1) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + if new_dtype: + torch_dtype = T.dtype(new_dtype).as_torch() + return A.view(N // M, M).view(dtype=torch_dtype) + return A.view(N // M, M) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reshape_view(): + # Test view with same dtype + run_view(1024, 32, T.float32) + run_view(2048, 64, T.float16) + + # Test view with dtype conversion + run_view(1024, 32, T.float32, T.float16) + run_view(2048, 64, T.float16, T.float32) + + +def view_shape_mismatch_test(N, M, dtype, new_dtype=None): + new_shape = [N // M, M + 1] + if new_dtype: + from tvm import DataType + + dtype_src = DataType(dtype) + dtype_dst = DataType(new_dtype) + src_bits = dtype_src.bits + dst_bits = dtype_dst.bits + scale = src_bits / dst_bits + new_shape[-1] = int(M * scale) + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + ): + with T.Kernel(1) as _: + A_viewed = T.view(A, new_shape, dtype=new_dtype) + T.copy(A_viewed, B) + + return main + + +def test_view_shape_mismatch(): + with pytest.raises(AssertionError): + view_shape_mismatch_test(1024, 32, 
T.float32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/language/test_tilelang_language_warp_reduce.py b/tilelang/original/testing/python/language/test_tilelang_language_warp_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..a8868013d24e32ba724ef5a76eb005d6aea42061 --- /dev/null +++ b/tilelang/original/testing/python/language/test_tilelang_language_warp_reduce.py @@ -0,0 +1,82 @@ +import torch + +import tilelang +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit +def get_kernel(reduce_op: str, dtype: str): + assert reduce_op in ["sum", "max", "min", "bitand", "bitor"] + + @T.prim_func + def main(x: T.Tensor((32), dtype)): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding(0) + local_val = T.alloc_local([1], dtype) + local_val[0] = x[tx] + reduced_val = T.alloc_local([1], dtype) + if reduce_op == "sum": + reduced_val[0] = T.warp_reduce_sum(local_val[0]) + elif reduce_op == "max": + reduced_val[0] = T.warp_reduce_max(local_val[0]) + elif reduce_op == "min": + reduced_val[0] = T.warp_reduce_min(local_val[0]) + elif reduce_op == "bitand": + reduced_val[0] = T.warp_reduce_bitand(local_val[0]) + elif reduce_op == "bitor": + reduced_val[0] = T.warp_reduce_bitor(local_val[0]) + x[tx] = reduced_val[0] + + return main + + +def test_warp_reduce_sum(): + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("sum", T.float32) + ref = torch.full_like(a, a.sum()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_max(): + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("max", T.float32) + print(kernel.get_kernel_source()) + ref = torch.full_like(a, a.max()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_min(): + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("min", T.float32) + ref = torch.full_like(a, a.min()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_bitand(): + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitand", T.int32) + ref_val = a[0] + for i in range(1, a.shape[0]): + ref_val = ref_val & a[i] + ref = torch.full_like(a, ref_val) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_bitor(): + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitor", T.int32) + ref_val = a[0] + for i in range(1, a.shape[0]): + ref_val = ref_val | a[i] + ref = torch.full_like(a, ref_val) + kernel(a) + torch.testing.assert_close(a, ref) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/layout/test_tilelang_layout_fused_replicate.py b/tilelang/original/testing/python/layout/test_tilelang_layout_fused_replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa5f6c42ec0bed605caebdc52242aee8534cece --- /dev/null +++ b/tilelang/original/testing/python/layout/test_tilelang_layout_fused_replicate.py @@ -0,0 +1,62 @@ +import pytest +import torch + +import tilelang +import tilelang.testing +import tilelang.language as T + +tilelang.testing.set_random_seed() + +VEC_SIZE = 32 + + +@tilelang.jit +def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): + @T.prim_func + def main( + a: T.Buffer((B, M, N), T.bfloat16), + a_out: T.Buffer((B, M, N), T.float32), + ): + with T.Kernel( + T.ceildiv(M, BLOCK_MN), + T.ceildiv(N, 
BLOCK_K), + B, + threads=128, + ) as (pid_m, pid_n, pid_b): + a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), T.float32) + offs_m = pid_m * BLOCK_MN + offs_n = pid_n * BLOCK_K + + for i, j in T.Parallel(BLOCK_MN, BLOCK_K): + idx = i * BLOCK_K + j + a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE] + + return main + + +def _require_cuda_tensor(shape, dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randn(*shape, device="cuda", dtype=dtype) + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +def test_layout_infer_compiles_and_runs(): + B, M, N = 1, 32, 64 + BLOCK_MN, BLOCK_K = 32, 64 + kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K) + + a = _require_cuda_tensor((B, M, N), torch.bfloat16) + a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device) + + # Ensure kernel compiles and executes without layout inversion failure + kernel(a, a_out) + + assert a_out.shape == a.shape + assert a_out.dtype == torch.float32 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/math/test_math_bitwise_reduce.py b/tilelang/original/testing/python/math/test_math_bitwise_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..044e0ea376c842e661afac288b77b3b39412e70b --- /dev/null +++ b/tilelang/original/testing/python/math/test_math_bitwise_reduce.py @@ -0,0 +1,114 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, +) +def bitwise_reduce( + M, + N, + block_M, + block_N, + name, + func, + clear=True, +): + @T.prim_func + def reduce_func( + A: T.Tensor((M, N), T.int32), + B: T.Tensor((M), T.int32), + Output: T.Tensor((M), T.int32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), T.int32) + A_fragment = T.alloc_fragment((block_M, block_N), T.int32) + B_shared = T.alloc_shared((block_M,), T.int32) + B_fragment = T.alloc_fragment((block_M), T.int32) + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(A_shared, A_fragment) + T.copy(B[by * block_M], B_shared) + T.copy(B_shared, B_fragment) + func(A_fragment, B_fragment, clear=clear) + T.copy(B_fragment, Output[by * block_M]) + + return reduce_func + + +def run_single_bitwise_reduce( + name, + func, + clear=True, +): + M, N = 32, 32 + block_M, block_N = 32, 32 + kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear) + + # Generate test data that exercises all bit patterns for robust bitwise reduce testing + a = torch.zeros((M, N), device="cuda", dtype=torch.int32) + + # Fill with patterns that will produce meaningful results for bitwise operations: + # - Different bit patterns across rows/columns + # - Mix of 0s and 1s in various positions + # - Some all-1s and all-0s patterns for edge cases + for i in range(M): + for j in range(N): + # Create varied bit patterns: + # Row-based pattern: alternating bits based on row index + row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row + + # Column-based pattern: different bit positions set based on column + col_pattern = 1 << (j % 31) # Single bit set at different positions + + # Combine patterns with XOR to create diverse bit distributions + # Add some deterministic "noise" based on position + position_factor = (i * N + 
j) % 256 + + # Final value combines all patterns + a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF + + if i % 4 == 0: + a[i, j] &= ~(0x1 << (i // 4)) + elif i % 2 == 0: + a[i, j] |= 0x1 << (i // 2) + + if name == "reduce_bitand": + expected = torch.full((M,), -1, device="cuda", dtype=torch.int32) + elif name == "reduce_bitor" or name == "reduce_bitxor": + expected = torch.full((M,), 0, device="cuda", dtype=torch.int32) + else: + raise ValueError("Invalid name: {}".format(name)) + + output = kernel(a, expected) + + for i in range(M): + for j in range(N): + if name == "reduce_bitand": + expected[i] = expected[i] & a[i, j] + elif name == "reduce_bitor": + expected[i] = expected[i] | a[i, j] + elif name == "reduce_bitxor": + expected[i] = expected[i] ^ a[i, j] + else: + raise ValueError("Invalid name: {}".format(name)) + assert torch.all(output == expected) + print("✓ {} with clear={} test passed".format(name, clear)) + + +@tilelang.testing.requires_cuda +def test_bitwise_reduce_ops(): + run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True) + run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True) + run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True) + run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False) + run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False) + run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/math/test_math_fast_math.py b/tilelang/original/testing/python/math/test_math_fast_math.py new file mode 100644 index 0000000000000000000000000000000000000000..3c50e95f4c27ca3715cbf0088972e521e5f06740 --- /dev/null +++ b/tilelang/original/testing/python/math/test_math_fast_math.py @@ -0,0 +1,320 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import re + + +def get_mathop_lines(source, mathop_name): + """Extract lines containing the mathop from CUDA source for debugging""" + lines = source.split("\n") + relevant_lines = [] + for i, line in enumerate(lines): + if mathop_name in line and ("(" in line): + # Include some context + start = max(0, i - 1) + end = min(len(lines), i + 2) + relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) + relevant_lines.append("---") + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + + +def check_fastmath_usage(source, mathop_name, expect_fastmath=False): + """Check source for fastmath/non-fastmath versions""" + fastmath_pattern = rf"__({mathop_name}f?)\b" + non_fastmath_pattern = rf"(? 
0: + print(f"Fastmath calls found: {fastmath_matches}") + if len(non_fastmath_matches) > 0: + print(f"Non-fastmath calls found: {non_fastmath_matches}") + print(f"Source preview for {mathop_name}:") + print(get_mathop_lines(source, mathop_name)) + + if expect_fastmath: + assert len(fastmath_matches) > 0, "Expected fastmath calls but found none" + print(f"✓ {mathop_name} correctly uses fastmath versions") + else: + assert len(fastmath_matches) == 0, f"Found unexpected fastmath calls: {fastmath_matches}" + assert len(non_fastmath_matches) > 0, f"No {mathop_name} calls found" + print(f"✓ {mathop_name} correctly uses non-fastmath versions") + + +def check_non_fastmath_usage(source, mathop_name): + """Check that source uses non-fastmath versions (no __ prefix)""" + check_fastmath_usage(source, mathop_name, expect_fastmath=False) + + +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test single-argument mathops. + T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) + """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) + + # Test with FAST_MATH disabled + kernel_no_fastmath = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + source_no_fastmath = kernel_no_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} ===") + print("FAST_MATH=False:") + + # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf) + check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False) + + print(f"✓ {mathop_name} compilation and execution test passed") + + +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test two-argument mathops to ensure they generate non-fastmath CUDA code. 
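+ Both FAST_MATH settings are compiled: the generated source must use the plain (non-__) intrinsic in either case, and the two kernels must agree numerically.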
+ """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) + + # Test with FAST_MATH disabled + kernel_no_fastmath = tilelang.compile( + main, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + # Test with FAST_MATH enabled + kernel_fastmath = tilelang.compile( + main, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + + source_no_fastmath = kernel_no_fastmath.get_kernel_source() + source_fastmath = kernel_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} (two args) ===") + print("FAST_MATH=False:") + check_non_fastmath_usage(source_no_fastmath, mathop_name) + + print("FAST_MATH=True:") + check_non_fastmath_usage(source_fastmath, mathop_name) + + # Test numerical correctness + torch_dtype = dtype.as_torch() + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + b = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if mathop_name == "pow": + a = torch.abs(a) + 0.1 + b = torch.clamp(b, -3, 3) # Limit exponent range + elif mathop_name == "fmod": + b = torch.abs(b) + 0.1 # Avoid division by zero + + c_no_fastmath = kernel_no_fastmath(a, b) + c_fastmath = kernel_fastmath(a, b) + + # Both should produce similar results + torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3) + print(f"✓ {mathop_name} numerical test passed") + + +def run_abs_test(): + """Test that abs correctly maps to fabs (not __fabsf) in generated CUDA code""" + M, N = 128, 128 + block_M, block_N = 32, 32 + + @T.prim_func + def main( + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = T.abs(A[by * block_M + i, bx * block_N + j]) + + kernel = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + source = kernel.get_kernel_source() + print("\n=== Testing abs (maps to fabs) ===") + check_non_fastmath_usage(source, "fabs") + + # Test numerical correctness + a = torch.randn(M, N, device="cuda", dtype=torch.float32) + b = kernel(a) + expected = torch.abs(a) + + torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5) + print("✓ abs numerical test passed") + + +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
+ """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) + + # Test with FAST_MATH enabled + kernel_fastmath = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + + source_fastmath = kernel_fastmath.get_kernel_source() + + print(f"\n=== Testing {mathop_name} (fastmath version) ===") + print("FAST_MATH=True:") + # Strip the __ prefix for checking in the CUDA source + cuda_mathop_name = mathop_name.lstrip("_") + check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) + + # Test numerical correctness + torch_dtype = dtype.as_torch() + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]: + a = torch.abs(a) + 0.1 + + b_fastmath = kernel_fastmath(a) + + # Compare with reference implementation + if cuda_mathop_name == "exp": + expected = torch.exp(a) + elif cuda_mathop_name == "log": + expected = torch.log(a) + else: + expected = b_fastmath # Just check compilation works + + torch.testing.assert_close(b_fastmath, expected, rtol=1e-3, atol=1e-3) + print(f"✓ {mathop_name} numerical test passed") + + +@tilelang.testing.requires_cuda +def test_mathops_generate_no_fastmath(): + """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" + # Based on test results, our tl.* intrinsics actually generate + # no fastmath versions + # This appears to be the intended behavior + single_arg_mathops = [ + ("exp", T.exp), + ("exp2", T.exp2), + ("exp10", T.exp10), + ("log", T.log), + ("log2", T.log2), + ("log10", T.log10), + ("sin", T.sin), + ("cos", T.cos), + ("tan", T.tan), + ("sinh", T.sinh), + ("cosh", T.cosh), + ("tanh", T.tanh), + ("atan", T.atan), + ("sqrt", T.sqrt), + ("rsqrt", T.rsqrt), + ("erf", T.erf), + ("floor", T.floor), + ("ceil", T.ceil), + ("trunc", T.trunc), + ("round", T.round), + ("nearbyint", T.nearbyint), + ] + + for name, func in single_arg_mathops: + run_single_arg_mathop_test(name, func, dtype=T.float32) + print(f"✓ {name} test passed") + + +@tilelang.testing.requires_cuda +def test_two_arg_mathops_fastmath(): + """Test all two-argument mathops""" + # Two argument mathops + two_arg_mathops = [ + ("pow", T.pow), + ("fmod", T.fmod), + ] + + for name, func in two_arg_mathops: + run_two_arg_mathop_test(name, func, dtype=T.float32) + + +@tilelang.testing.requires_cuda +def test_abs_maps_to_fabs(): + """Test that abs correctly maps to fabs""" + run_abs_test() + + +@tilelang.testing.requires_cuda +def test_fastmath_versions(): + """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code""" + # Test fastmath versions + fastmath_mathops = [ + ("__exp", T.__exp), + ("__exp10", T.__exp10), + ("__log", T.__log), + ("__log2", T.__log2), + ("__log10", T.__log10), + ("__tan", T.__tan), + ("__cos", T.__cos), + ("__sin", T.__sin), + ] + + for name, func in fastmath_mathops: + run_fastmath_mathop_test(name, func, dtype=T.float32) + print(f"✓ {name} test passed") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/math/test_math_ieee_math.py b/tilelang/original/testing/python/math/test_math_ieee_math.py 
new file mode 100644 index 0000000000000000000000000000000000000000..5d49880027e3c61ef4e597c5e3168a699e73fd81 --- /dev/null +++ b/tilelang/original/testing/python/math/test_math_ieee_math.py @@ -0,0 +1,230 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import pytest + + +def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype=T.float32): + """ + Test IEEE-compliant math operations with specified rounding modes. + """ + + # Define the appropriate function based on operation type to avoid TVM parsing conflicts + if mathop_name == "ieee_fmaf": + + @T.prim_func + def main_func( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + D: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + D[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], + B[by * block_M + i, bx * block_N + j], + C[by * block_M + i, bx * block_N + j], + rounding_mode, + ) + + out_idx = [3] + num_inputs = 3 + elif mathop_name in ["ieee_add", "ieee_sub", "ieee_mul", "ieee_fdiv"]: + + @T.prim_func + def main_func( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode + ) + + out_idx = [2] + num_inputs = 2 + else: # Single argument operations + + @T.prim_func + def main_func( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + for i, j in T.Parallel(block_M, block_N): + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode) + + out_idx = [1] + num_inputs = 1 + + # Test compilation + kernel = tilelang.compile( + main_func, + out_idx=out_idx, + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===") + print(f"✓ {mathop_name} compilation test passed") + + # Test numerical execution + torch_dtype = dtype.as_torch() + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + if num_inputs >= 2: + b = torch.randn(M, N, device="cuda", dtype=torch_dtype) + if num_inputs == 3: + c = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if mathop_name in ["ieee_frcp", "ieee_fsqrt"]: + a = torch.abs(a) + 0.1 + elif mathop_name == "ieee_fdiv": + b = torch.abs(b) + 0.1 # Avoid division by zero + + # Execute kernel + try: + if num_inputs == 1: + result = kernel(a) + elif num_inputs == 2: + result = kernel(a, b) + else: # num_inputs == 3 + result = kernel(a, b, c) + + assert result is not None + print(f"✓ {mathop_name} numerical execution test passed") + except Exception as e: + print(f"Warning: {mathop_name} execution failed: {e}") + + +def test_rounding_mode_validation(): + """Test that invalid rounding modes raise ValueError""" + + # Test with invalid rounding mode + with pytest.raises(ValueError, match="Invalid rounding mode"): + T.ieee_add(1.0, 2.0, "invalid_mode") + + with pytest.raises(ValueError, 
match="Invalid rounding mode"): + T.ieee_mul(1.0, 2.0, "xy") + + with pytest.raises(ValueError, match="Invalid rounding mode"): + T.ieee_fsqrt(4.0, "bad_mode") + + print("✓ Rounding mode validation test passed") + + +@tilelang.testing.requires_cuda +def test_ieee_add_all_rounding_modes(): + """Test IEEE addition with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_add", T.ieee_add, rounding_mode=mode) + print(f"✓ ieee_add with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_sub_all_rounding_modes(): + """Test IEEE subtraction with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_sub", T.ieee_sub, rounding_mode=mode) + print(f"✓ ieee_sub with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_mul_all_rounding_modes(): + """Test IEEE multiplication with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_mul", T.ieee_mul, rounding_mode=mode) + print(f"✓ ieee_mul with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_fmaf_all_rounding_modes(): + """Test IEEE fused multiply-add with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_fmaf", T.ieee_fmaf, rounding_mode=mode) + print(f"✓ ieee_fmaf with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_frcp_all_rounding_modes(): + """Test IEEE reciprocal with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_frcp", T.ieee_frcp, rounding_mode=mode) + print(f"✓ ieee_frcp with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_fsqrt_all_rounding_modes(): + """Test IEEE square root with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_fsqrt", T.ieee_fsqrt, rounding_mode=mode) + print(f"✓ ieee_fsqrt with {mode} passed") + + +@tilelang.testing.requires_cuda +def test_ieee_frsqrt_rn_only(): + """Test IEEE reciprocal square root (round to nearest only)""" + + @T.prim_func + def main( + A: T.Tensor((128, 128), T.float32), + B: T.Tensor((128, 128), T.float32), + ): + with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by): + for i, j in T.Parallel(32, 32): + B[by * 32 + i, bx * 32 + j] = T.ieee_frsqrt(A[by * 32 + i, bx * 32 + j]) + + kernel = tilelang.compile( + main, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + }, + ) + + print("\n=== Testing ieee_frsqrt (rn only) ===") + print("✓ ieee_frsqrt compilation test passed") + + # Test numerical execution + a = torch.abs(torch.randn(128, 128, device="cuda", dtype=torch.float32)) + 0.1 + + try: + result = kernel(a) + assert result is not None + print("✓ ieee_frsqrt numerical execution test passed") + except Exception as e: + print(f"Warning: ieee_frsqrt execution failed: {e}") + + +@tilelang.testing.requires_cuda +def test_ieee_fdiv_all_rounding_modes(): + """Test IEEE division with all rounding modes""" + rounding_modes = ["rn", "rz", "ru", "rd"] + + for mode in rounding_modes: + run_ieee_math_test("ieee_fdiv", T.ieee_fdiv, rounding_mode=mode) + print(f"✓ ieee_fdiv with {mode} passed") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/tilelang/original/testing/python/metal/test_metal_codegen.py b/tilelang/original/testing/python/metal/test_metal_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..5349bbec58e6dad10066d7447b4462a50426e6f1 --- /dev/null +++ b/tilelang/original/testing/python/metal/test_metal_codegen.py @@ -0,0 +1,82 @@ +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +import torch + + +@tilelang.jit(execution_backend="torch") +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float32, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2) + T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2) + + for i, j, k in T.Parallel(block_M, block_N, block_K): + C_local[i, j] += A_shared[i, k] * B_shared[k, j] + + T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2) + + return gemm + + +def assert_gemm( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float32, + accum_dtype=T.float32, + atol=1e-8, +): + jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + + torch_dtype = dtype.as_torch() + a, b = None, None + if "int" in dtype: + a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps") + b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps") + else: + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_dtype, device="mps") + + jit_kernel(a, b, c) + + assert torch.allclose(a @ b, c, atol=atol) + + assert jit_kernel.kernel_source is not None + + +@tilelang.testing.requires_metal +def test_gemm_float32(): + assert_gemm(1024, 1024, 1024, 16, 16, 16) + + +@tilelang.testing.requires_metal +def test_gemm_float16(): + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.float16, atol=1) + + +@tilelang.testing.requires_metal +def test_gemm_int32(): + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.int32, atol=1) + + +if __name__ == "__main__": + if torch.mps.is_available(): + tilelang.testing.main() diff --git a/tilelang/original/testing/python/profiler/test_tilelang_profiler.py b/tilelang/original/testing/python/profiler/test_tilelang_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..09d894c599ad84362645628abf0a651a8d73c0ad --- /dev/null +++ b/tilelang/original/testing/python/profiler/test_tilelang_profiler.py @@ -0,0 +1,54 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + 
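# Pipelined loop over K tiles: each stage prefetches the next A/B tiles into shared memory while T.gemm consumes the current ones +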
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def test_profiler(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + ref_c = a @ b + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + # benchmark + profiler = kernel.get_profiler() + + # use cupti backend + cupti_latency = profiler.do_bench(backend="cupti") + + # use event backend + event_latency = profiler.do_bench(backend="event") + print(f"cupti Latency: {cupti_latency}ms") + print(f"event Latency: {event_latency}ms") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py b/tilelang/original/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..083373eb7b32ba7f29099daafaff4bf133230797 --- /dev/null +++ b/tilelang/original/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py @@ -0,0 +1,52 @@ +import pytest +import torch + +import tilelang +import tilelang.language as T +import tilelang.testing + + +@tilelang.jit +def dynamic_smem_kernel(): + # Symbolic length to drive dynamic shared memory allocation + length = T.symbolic("len", dtype=T.int32) # noqa: F821 + + @T.prim_func + def main(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821 + # Launch a simple kernel that exercises the dynamically sized shared allocation + # with a single T.copy; the contents of global_tensor are never checked. 
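+ # alloc_shared with the symbolic `length` lowers to a dynamic shared-memory allocation, so the shared-memory size is decided per launch from the runtime tensor length.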
+ with T.Kernel(1, threads=32) as _: + buffer_shared = T.alloc_shared((length,), dtype=T.int32) # noqa: F821 + T.copy(buffer_shared, global_tensor) + + return main + + +def _require_cuda_tensor(shape, dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randint(0, 100, shape, dtype=dtype, device="cuda") + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +def _run_and_check(kernel, n): + a = _require_cuda_tensor((n,), torch.int32) + kernel(a) + torch.cuda.synchronize() + + +def test_dynamic_shared_memory_varies_across_calls(): + kernel = dynamic_smem_kernel() + + # Run with different dynamic shared memory sizes across invocations + _run_and_check(kernel, 100) + _run_and_check(kernel, 200) + # Repeat sizes to exercise attribute caching path + _run_and_check(kernel, 200) + _run_and_check(kernel, 100) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..67123cb8c0f57e0122ecc16941120bed75204003 --- /dev/null +++ b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -0,0 +1,540 @@ +import tilelang.language as T +from tilelang import tvm as tvm +import tilelang.testing +import pytest + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_ss( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = 
C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved") +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + ], +) +def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + 
num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved") +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + ], +) +def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + 
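# SR variant: the B tile is copied from shared memory into a register fragment before the MMA, while A is consumed directly from shared memory. +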
T.copy(B_shared, B_frag) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128), + (128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + ], +) +def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + 
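# RR variant: both operand tiles are staged through register fragments (A_frag and B_frag) before gemm_v2; C_local holds the accumulator in accum_dtype. +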
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128), + (128, 8, 128, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 2, 128), + (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 32, 2, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + ], +) +def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py 
b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f4a29c9216d0f3eddc384a3ab070cfcb44afcf --- /dev/null +++ b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -0,0 +1,357 @@ +import pytest +import torch +import tilelang +import tilelang.testing +import tilelang.language as T + +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter + +torch.backends.cuda.matmul.allow_tf32 = False +# torch.manual_seed(42) # only enable when debugging + + +def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): + is_8bit = "8" in in_dtype + is_unsigned = "uint" in in_dtype + is_int = "int" in in_dtype + if is_int: + if is_8bit: + low, high = (0, 4) if is_unsigned else (-2, 2) + else: + low, high = (0, 128) if is_unsigned else (-64, 64) + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") + else: + A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype)) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) + return A, B + + +def matmul_sp_sm90( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + E_factor = 4 if in_dtype == T.float32 else 8 + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8") + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + } + ) + T.disable_warp_group_reg_alloc() + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + 
threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = T.int32 if is_8_bit else T.int16 + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def normalize(tensor, max_range=100.0): + assert max_range <= 448.0 + max_v = tensor.abs().max().clamp(1e-4) + scaler = max_range / max_v + return tensor * scaler + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def run_gemm_sp( + kernel, + M, + N, + K, + in_dtype, + out_dtype, + block_K, + trans_A, + trans_B, +): + kernel = tilelang.compile( + kernel, + out_idx=[-1], + ) + A, B = generate_dense_input( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + ) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) + + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + if "float8" in in_dtype or "int8" in in_dtype: + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + if "float8" in in_dtype: + diff = calc_diff(C_sp, C) + assert diff < 1e-3, f"{diff=}" + else: + torch_assert_close( + C_sp.to(torch.float32), + C.to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def run_gemm_sp_sm90( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, + trans_A, + trans_B, +): + kernel = matmul_sp_sm90( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + trans_A, + trans_B, + ) + run_gemm_sp( + kernel, + M, + N, + K, + in_dtype, + out_dtype, + block_K, + trans_A, + trans_B, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(8, 
0) +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def run_gemm_sp_sm80( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, + trans_A, + trans_B, +): + kernel = matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + trans_A, + trans_B, + ) + run_gemm_sp( + kernel, + M, + N, + K, + in_dtype, + out_dtype, + block_K, + trans_A, + trans_B, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", + [ + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 0, 256, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), + ], +) +def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): + run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(8, 0) +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", + [ + (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 32, 0, 32, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 64, 0, 32, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 1, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 3, 128, False, False), + (512, 1024, 768, T.int8, T.int32, T.int32, 32, 32, 64, 0, 32, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 0, 32, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 1, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), + ], +) +def 
test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): + run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9d232902c68cdce11eed26ddd9cd9874cff9e25d --- /dev/null +++ b/tilelang/original/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -0,0 +1,633 @@ +import pytest +from tilelang import tvm as tvm +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse +from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter + +import tilelang.testing +import torch +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_ss( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + 
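# Compile with TMA lowering and warp-specialized scheduling disabled, keeping the sparse kernel on the plain software-pipelined path. +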
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): + is_8bit = "8" in in_dtype + is_unsigned = "uint" in in_dtype + is_int = "int" in in_dtype + if is_int: + if is_8bit: + low, high = (0, 4) if is_unsigned else (-2, 2) + else: + low, high = (0, 128) if is_unsigned else (-64, 64) + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") + else: + A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) + return A, B + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int32, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + ], +) +def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_sp_v2(A_frag, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + ], +) +def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + 
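# RS variant: the compressed A tile is staged through a register fragment (A_frag) before gemm_sp_v2; B and the sparsity metadata stay in shared memory. +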
run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_sp_v2(A_shared, E_shared, B_frag, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 
32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + ], +) +def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_sp_v2(A_frag, E_shared, B_frag, C_frag, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + 
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(map_torch_type(out_dtype)).to(torch.float32), + C.to(map_torch_type(out_dtype)).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + print("pass") + + +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128), + (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + ], +) +def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_nullable_buffer_params.py b/tilelang/original/testing/python/transform/test_nullable_buffer_params.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbde254b5adab8d27de1708d37460a748f0bcd9 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_nullable_buffer_params.py @@ -0,0 +1,73 @@ +import torch +import tilelang +import tilelang.testing +from tilelang import language as T + + +def test_nullable_shared_shape(): + """Test that buffers sharing a shape variable can be nullable.""" + + @tilelang.jit + def get_kernel(): + m = T.dynamic("m") + + @T.prim_func + def test_kernel( + a: T.Tensor[(m,), T.int32], + b: T.Tensor[(m,), T.int32], + c: T.Tensor[(m,), T.int32], + ): + with T.Kernel(1, threads=64): + tx = T.get_thread_binding() + if tx == 0: + T.print(m) + + return test_kernel + + m = 200 + kernel = get_kernel() + + # Create test tensors + tensor_a = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32) + tensor_b = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32) + tensor_c = torch.randn((m,), device="cuda", 
dtype=torch.float32).to(torch.int32) + + print("Test 1: All tensors provided") + kernel(tensor_a, tensor_b, tensor_c) + print("✓ PASS: All tensors provided") + + print("\nTest 2: Only first tensor provided") + kernel(tensor_a, None, None) + print("✓ PASS: Only first tensor provided") + + print("\nTest 3: Only middle tensor provided") + kernel(None, tensor_b, None) + print("✓ PASS: Only middle tensor provided") + + print("\nTest 4: Only last tensor provided") + kernel(None, None, tensor_c) + print("✓ PASS: Only last tensor provided") + + print("\nTest 5: First and last tensors provided") + kernel(tensor_a, None, tensor_c) + print("✓ PASS: First and last tensors provided") + + print("\nTest 6: All tensors are None (should fail)") + try: + kernel(None, None, None) + print("✗ FAIL: Should have raised an error") + return False + except RuntimeError as e: + if "at least one non-null buffer" in str(e): + print(f"✓ PASS: Correctly rejected with error: {e}") + else: + print(f"✗ FAIL: Wrong error message: {e}") + return False + + print("\n" + "=" * 60) + print("All tests passed!") + return True + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_readonly_param_const_codegen.py b/tilelang/original/testing/python/transform/test_readonly_param_const_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..0d255b46b4f083565b7b7e52347ce5841d400990 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_readonly_param_const_codegen.py @@ -0,0 +1,54 @@ +import tilelang.language as T +from tilelang.engine.lower import lower +from tilelang.jit.adapter.utils import match_declare_kernel + + +def _simple_add_kernel(): + @T.prim_func + def main( + x: T.Tensor((128,), T.float32), + y: T.Tensor((128,), T.float32), + ): + # One-dimensional kernel; writes y from x without modifying x + with T.Kernel(128, threads=32) as pid: + y[pid] = x[pid] + 1.0 + + return main + + +def test_codegen_emits_const_for_readonly_params(): + # Lower without device compilation to retrieve CUDA source reliably + func = _simple_add_kernel() + artifact = lower(func, target="cuda", enable_device_compile=False) + + src = artifact.kernel_source + print(src) + assert 'extern "C" __global__' in src + + # Extract kernel signature and check qualifiers + lparen = match_declare_kernel(src) + rparen = src.find(")", lparen) + assert rparen != -1 + signature = src[lparen:rparen] + + # x is read-only: should be `const` and `__restrict__` + assert "const float* __restrict__" in signature + # y is written: must not be const, but still `__restrict__` due to noalias + # We ensure there is a non-const float* parameter with __restrict__ as well + assert "const float* __restrict__ x" in src or "const float *__restrict__ x" in src + assert " float* __restrict__ y" in src or " float *__restrict__ y" in src + + # Also validate the function attribute carries read-only param indices + # Expect only the first handle parameter (x) to be marked read-only + device_mod = artifact.device_mod + prim_funcs = [f for f in device_mod.functions.values() if hasattr(f, "attrs")] + assert prim_funcs, "No PrimFunc found in device module" + pf = prim_funcs[0] + ro = pf.attrs.get("tl.readonly_param_indices") + assert ro is not None, "Expected tl.readonly_param_indices to be present" + ro_list = [int(i) for i in ro] + assert 0 in ro_list and 1 not in ro_list + + +if __name__ == "__main__": + test_codegen_emits_const_for_readonly_params() diff --git 
a/tilelang/original/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/tilelang/original/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..cdff6fb1d32ac50286f766c291d131832fcd13e0 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -0,0 +1,51 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + mod = tl.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) + + +def test_trival_pipeline(): + @T.prim_func + def before(A: T.Tensor((16, 1), T.float32), C: T.Tensor((16, 1), T.float32)): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + @T.prim_func + def expected(A_handle: T.handle, C_handle: T.handle): + A = T.match_buffer(A_handle, (16, 1), strides=(1, 1)) + C = T.match_buffer(C_handle, (16, 1), strides=(1, 1)) + tx = T.launch_thread("threadIdx.x", 16) + B = T.decl_buffer((2, 16, 1), scope="shared") + B[0, tx, 0] = A[tx, 0] * T.float32(2.0) + for i in range(0): + B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0) + C[tx, i] = B[i, tx, 0] + T.float32(1.0) + C[tx, 0] = B[0, tx, 0] + T.float32(1.0) + + _check(before, expected) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_cluster_planning.py b/tilelang/original/testing/python/transform/test_tilelang_transform_cluster_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..296c6ce947243d7a78f825950b8e4b46c2c3a0ad --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_cluster_planning.py @@ -0,0 +1,63 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.ClusterPlanning()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tvm.tir.transform.LowerOpaqueBlock()(transformed) + + tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +def test_cluster_planning(): + @T.prim_func + def before(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)): + with T.Kernel(8, 8, threads=128) 
as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float16) + B_shared = T.alloc_shared((32, 128), T.float16) + C_local = T.alloc_fragment((128, 128), T.float32) + + T.clear(C_local) + + for ko in T.Pipelined(32, num_stages=3): + T.copy(A[by * 128, ko * 32], A_shared) + T.copy(B[ko * 32, bx * 128], B_shared) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * 128, bx * 128]) + + @T.prim_func + def after(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)): + T.func_attr({"clusterIdx.y": T.int32(2)}) + with T.Kernel(8, 8, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float16) + B_shared = T.alloc_shared((32, 128), T.float16) + C_local = T.alloc_fragment((128, 128), T.float32) + + T.clear(C_local) + + for ko in T.Pipelined(32, num_stages=3): + T.copy(A[by * 128, ko * 32], A_shared) + T.copy(B[ko * 32, bx * 128], B_shared) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * 128, bx * 128]) + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py b/tilelang/original/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py new file mode 100644 index 0000000000000000000000000000000000000000..559b2ffb4392fe4449ed3e03357ead2293f30a16 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py @@ -0,0 +1,167 @@ +import math + +import tilelang +import tilelang.language as T + + +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + + batch = T.int32(batch) + heads = T.int32(heads) + seq_len = T.int32(seq_len) + dim = T.int32(dim) + downsample_len = T.int32(downsample_len) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.bfloat16 + accum_dtype = T.float32 + block_mask_dtype = "bool" + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.Tensor([block_M, dim], dtype), + K_shared: T.Tensor([block_N, dim], dtype), + acc_s: T.Tensor([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.Tensor([block_M, dim], dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + acc_o: T.Tensor([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Tensor([block_M, block_N], accum_dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + scores_max: T.Tensor([block_M], accum_dtype), + scores_max_prev: T.Tensor([block_M], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), + scores_sum: T.Tensor([block_M], accum_dtype), + logsum: T.Tensor([block_M], accum_dtype), + ): + 
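# Streaming-softmax step: refresh the running row max, rescale the accumulated logsum by the change in max, and re-exponentiate the current scores against the new max. +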
T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Tensor([block_M, dim], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def test_sta_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 24, 82944, 128 + + # Create sparse mask (downsampled to block level) + tile_size = (4, 8, 8) + BLOCK = tile_size[0] * tile_size[1] * tile_size[2] + downsample_factor 
= BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = tilelang.compile(program, out_idx=[4], pass_configs={"tl.config_index_bitwidth": 64}) + + cuda_source = kernel.get_kernel_source() + + assert "int64_t" in cuda_source + + +if __name__ == "__main__": + test_sta_attention() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/tilelang/original/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..533a62fc683332c4f13ba29db9716c26a45374e1 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -0,0 +1,206 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing +from tvm import tir + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tir.transform.LowerOpaqueBlock()(transformed) + + tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +def test_lower_fence_proxy(): + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.decl_buffer((32,), scope="local") + for i in T.unroll(16): + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.decl_buffer((32,), scope="local") + for i in T.unroll(16): + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.fence_proxy_async() + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + + _check(before, after) + + +def test_async_to_generic_no_double_fence(): + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") + B_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") + T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16) + T.fence_proxy_async() + T.call_extern("handle", "generic_op") + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = 
tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + + def _count_fences(stmt): + count = 0 + + def visit(node): + nonlocal count + if isinstance(node, tir.Evaluate): + call = node.value + if isinstance(call, tir.Call): + op = call.op + name = getattr(op, "name", None) + if name == "tl.fence_proxy_async": + count += 1 + + tir.stmt_functor.post_order_visit(stmt, visit) + return count + + assert _count_fences(mod["main"].body) == 1 + + +def test_proxy_hint_override(): + @T.prim_func + def before(): + with T.Kernel(8): + T.evaluate(T.call_extern("handle", "custom_async")) + with T.attr("proxy_scope", "tl.proxy_hint", "neutral"): + T.evaluate(T.call_extern("handle", "custom_generic")) + T.evaluate(T.call_extern("handle", "custom_async_tail")) + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + + def _has_fence(stmt): + result = False + + def visit(node): + nonlocal result + if isinstance(node, tir.Evaluate): + call = node.value + if isinstance(call, tir.Call): + op = call.op + name = getattr(op, "name", None) + if name == "tl.fence_proxy_async": + result = True + + tir.stmt_functor.post_order_visit(stmt, visit) + return result + + assert not _has_fence(mod["main"].body) + + +def test_tma_store_sync_injection(): + @T.prim_func + def before(): + with T.Kernel(8): + A_global = T.decl_buffer((128,), T.float16, scope="global") + T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data)) + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + + arrives = 0 + waits = 0 + + def visit(node): + nonlocal arrives, waits + if isinstance(node, tir.Evaluate): + call = node.value + if isinstance(call, tir.Call): + name = getattr(call.op, "name", None) + if name == "tl.tma_store_arrive": + arrives += 1 + elif name in ("tl.tma_store_wait", "tl.tma_store_wait<0>"): + waits += 1 + + tir.stmt_functor.post_order_visit(mod["main"].body, visit) + assert arrives == 1 + assert waits == 1 + + +def test_wgmma_marked_async(): + @T.prim_func + def before(): + with T.Kernel(1): + A_shared = T.decl_buffer((1,), T.float16, scope="shared") + desc_a = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma") + desc_b = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma") + C_local = T.decl_buffer((32,), T.float16, scope="local") + A_shared[0] = T.float16(0) + T.warpgroup_arrive() + T.ptx_wgmma_ss( + T.float16, + "m64n64k16", + T.bool(True), + T.bool(True), + "fp16", + "fp16", + "fp16", + desc_a.data, + T.int32(0), + desc_b.data, + T.int32(0), + C_local.data, + T.int32(0), + T.bool(True), + 1, + 1, + ) + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + order = [] + + def visit(node): + if isinstance(node, tir.Evaluate): + call = node.value + if isinstance(call, tir.Call): + order.append(getattr(call.op, "name", "")) + + tir.stmt_functor.post_order_visit(mod["main"].body, visit) + + assert "tl.ptx_wgmma_ss" in order + assert "tl.fence_proxy_async" in order + assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/tilelang/original/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py b/tilelang/original/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py new file mode 100644 index 0000000000000000000000000000000000000000..1885c7c4b31a3d493fc8ab5c44be28db4bc27554 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py @@ -0,0 +1,139 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing +from tvm import tir + + +def test_inject_set_max_nreg(): + """Test the InjectSetMaxNReg pass""" + + @T.prim_func + def before(A: T.Tensor((512, 512), T.float16), B: T.Tensor((512, 512), T.float16)): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 128) + + with T.block(""): + T.reads(A[by * 64, 0:512], B[0:512, bx * 64]) + T.writes() + + # Add set_max_nreg hints + T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24 + T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240 + + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + + T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) + T.attr([128, 128], "kWarpSpecializationScope", 0) + + if v >= 128: + # Producer branch - should have set_max_nreg(24, 0) + for k in range(16): + T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1)) + if v - 128 == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) + else: + # Consumer branch - should have set_max_nreg(240, 1) + for k in range(16): + T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) + T.call_extern( + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + + # Apply the InjectSetMaxNReg pass + func = before + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + + # Check that set_max_nreg calls are properly injected + main_func = mod["main"] + set_max_nreg_calls = [] + + def collect_set_max_nreg(stmt): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): + set_max_nreg_calls.append(stmt.value) + + tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) + + # We should have at least 2 set_max_nreg calls (one for producer, one for consumer) + assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}" + + print("InjectSetMaxNReg test passed!") + + +def test_inject_set_max_nreg_no_set_max_nreg(): + """Test the InjectSetMaxNReg pass with no_set_max_nreg""" + + @T.prim_func + def before_no_set_max_nreg(A: 
T.Tensor((512, 512), T.float16)): + bx = T.launch_thread("blockIdx.x", 8) + v = T.launch_thread("threadIdx.x", 128) + + with T.block(""): + T.reads(A[bx * 64, 0:64]) + T.writes() + + # Add no_set_max_nreg to disable register hinting + T.disable_warp_group_reg_alloc() + + T.create_list_of_mbarrier(128, 128) + T.attr([128, 128], "kWarpSpecializationScope", 0) + + if v >= 128: + # Producer branch - should not have set_max_nreg calls + T.evaluate(0) + else: + # Consumer branch - should not have set_max_nreg calls + T.evaluate(0) + + # Apply the InjectSetMaxNReg pass + func = before_no_set_max_nreg + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + + # Check that no set_max_nreg calls are injected when no_set_max_nreg is present + main_func = mod["main"] + set_max_nreg_calls = [] + + def collect_set_max_nreg(stmt): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): + set_max_nreg_calls.append(stmt.value) + + tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) + + # Should have no set_max_nreg calls when no_set_max_nreg is present + assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}" + + print("InjectSetMaxNReg with no_set_max_nreg test passed!") + + +if __name__ == "__main__": + # tilelang.testing.main() + test_inject_set_max_nreg() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_layout_inference.py b/tilelang/original/testing/python/transform/test_tilelang_transform_layout_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..82fcd19ab932e63f476f9fc20ee859a4f1bdc40a --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -0,0 +1,105 @@ +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang as tl +import tilelang.language as T +import tilelang.testing +import pytest + +auto_target = tvm.target.Target(determine_target("auto")) + + +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, T.float16), + ], +) +def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): + N = tvm.te.var("n") + K = tvm.te.var("k") + + def before(): + @T.prim_func + def main( + B: T.Tensor((K, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared((block_K, block_N), dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + t = thread_bindings + for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): + for vec in T.Parallel(vec_load_b): + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) + + return tvm.IRModule({"main": main}) + + def 
after(): + @T.prim_func + def main( + B: T.Tensor((K, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared((block_K, block_N), dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + t = thread_bindings + for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: + for vec in T.vectorized(vec_load_b): + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) + else: + for vec in T.serial(vec_load_b): + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) + + return tvm.IRModule({"main": main}) + + with tvm.target.Target(auto_target): + mod = tvm.tir.transform.BindTarget(auto_target)(before()) + mod = tl.transform.LayoutInference()(mod) + mod = tvm.tir.transform.Simplify()(mod) + ref_mod = tvm.tir.transform.BindTarget(auto_target)(after()) + ref_mod = tvm.tir.transform.Simplify()(ref_mod) + # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass + # This loop is "for vec in T.parallel(1)", + # Since the loop var "vec" is never used in the loop body, it does not affect the correctness + tvm.ir.structural_equal(mod, ref_mod) + # tvm.ir.assert_structural_equal(mod, ref_mod) + + +if __name__ == "__main__": + # tilelang.testing.main() + test_loop_tail_split(64, 64, 32, 128, 8, T.float16) diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_negative_index.py b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_negative_index.py new file mode 100644 index 0000000000000000000000000000000000000000..26c151141f8132e17f81199d14838c9a0d9c9075 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_negative_index.py @@ -0,0 +1,342 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, expected): + """Helper function to verify structural equality after transformations""" + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.LegalizeNegativeIndex()(mod) + expected = tvm.IRModule.from_expr(expected.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], expected["main"], True) + + +def test_buffer_load_negative_index_legalized(): + """ + Test that negative indices are legalized by adding buffer extent. 
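+ For example, with a 1-D buffer of extent 1024, the load A[-1] is rewritten to A[1023].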
+ """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + value = A[-1] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + value = A[1023] # A[-1] becomes A[1023] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + _check(before, after) + + +def test_buffer_load_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices - only negative ones are legalized. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), T.float32)): + value = A[-1, 10] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512), T.float32)): + value = A[1023, 10] # A[-1, 10] becomes A[1023, 10] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + _check(before, after) + + +def test_buffer_load_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), T.float32)): + value = A[-1, -2, -3] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), T.float32)): + value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253 + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + _check(before, after) + + +def test_buffer_load_negative_index_in_expression(): + """ + Test negative index as part of a larger expression. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + B = T.alloc_buffer((1024,), T.float32) + for i in T.serial(1, 1024): + value = A[-i] + B[-i] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + B = T.alloc_buffer((1024,), T.float32) + for i in T.serial(1, 1024): + value = A[1024 - i] + B[1024 - i] = value + + _check(before, after) + + +def test_buffer_load_non_negative_index_unchanged(): + """ + Test that non-negative indices remain unchanged. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + value = A[0] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + # No changes expected for non-negative indices + value = A[0] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + _check(before, after) + + +def test_buffer_load_unknown_sign_index_warning(): + """ + Test that indices with unknown sign trigger warnings but are processed. + This test mainly checks that the pass doesn't crash on unknown signs. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + i = T.Var("i", T.int32) + value = A[i] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + i = T.Var("i", T.int32) + # Unknown sign indices should remain unchanged + value = A[i] + B = T.alloc_buffer((1,), T.float32) + B[0] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + vec = T.Broadcast(-1, 4) + value = A[vec] + B = T.alloc_buffer((4,), T.float32) + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + # vec is unused and can be delimed by Simplify. 
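+ # The legalized load below indexes with T.Broadcast(1023, 4) instead.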
+ vec = T.Broadcast(-1, 4) # noqa: F841 + value = A[T.Broadcast(1023, 4)] + B = T.alloc_buffer((4,), T.float32) + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_ramp(): + """ + Test negative indices in vectorized operations (ramp case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + value = A[vec] + B = T.alloc_buffer((4,), T.float32) + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + value = A[T.Ramp(1020, 1, 4)] + B = T.alloc_buffer((4,), T.float32) + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_nested_buffer_loads(): + """ + Test legalization with nested buffer load expressions. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), T.float32)): + inner_val = A[-1, 10] + outer_val = A[inner_val.astype(T.int32), -2] + B = T.alloc_buffer((1,), T.float32) + B[0] = outer_val + + @T.prim_func + def after(A: T.Tensor((1024, 512), T.float32)): + inner_val = A[1023, 10] + outer_val = A[inner_val.astype(T.int32), 510] + B = T.alloc_buffer((1,), T.float32) + B[0] = outer_val + + _check(before, after) + + +def test_buffer_store_negative_index(): + """ + Test negative indices in buffer store operations are legalized. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + A[-1] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + A[1023] = 42.0 + + _check(before, after) + + +def test_buffer_store_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices in buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), T.float32)): + A[-1, 10] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512), T.float32)): + A[1023, 10] = 42.0 + + _check(before, after) + + +def test_buffer_store_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions for buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), T.float32)): + A[-1, -2, -3] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), T.float32)): + A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253 + + _check(before, after) + + +def test_buffer_store_negative_index_in_expression(): + """ + Test negative index as part of a larger expression in buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + for i in T.serial(1, 1024): + A[-i] = i * 2.0 + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + for i in T.serial(1, 1024): + A[1024 - i] = i * 2.0 + + _check(before, after) + + +def test_buffer_store_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized store operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + vec = T.Broadcast(-1, 4) + values = T.Broadcast(42.0, 4) + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + # vec is unused and can be delimed by Simplify. + vec = T.Broadcast(-1, 4) # noqa: F841 + values = T.Broadcast(42.0, 4) + A[T.Broadcast(1023, 4)] = values + + _check(before, after) + + +def test_buffer_store_vector_index_negative_ramp(): + """ + Test negative indices in vectorized store operations (ramp case). 
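+ A Ramp(-4, 1, 4) index covers [-4, -3, -2, -1] and is rewritten to Ramp(1020, 1, 4).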
+ """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32)): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0] + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32)): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + values = T.Ramp(0.0, 1.0, 4) + A[T.Ramp(1020, 1, 4)] = values + + _check(before, after) + + +def test_buffer_store_nested_in_condition(): + """ + Test negative index buffer store within conditional statements. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), T.float32), flag: T.int32): + if flag > 0: + A[-1] = 42.0 + else: + A[-2] = 24.0 + + @T.prim_func + def after(A: T.Tensor((1024,), T.float32), flag: T.int32): + if flag > 0: + A[1023] = 42.0 + else: + A[1022] = 24.0 + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py new file mode 100644 index 0000000000000000000000000000000000000000..4f75fa05d00f1c001efffd6a309500f4205b1e76 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -0,0 +1,133 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): + dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N), dtype=dtype) + tid = T.get_thread_binding() + for j in T.serial(N): + A_shared[tid, j] = A[tid + M_offset, j + N_offset] + + @T.prim_func + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N), dtype=dtype) + tid = T.get_thread_binding() + + T.reads(A[tid + M_offset, N_offset : N + N_offset]) + for j in T.serial(N): + A_shared[tid, j] = T.if_then_else( + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) + + return main, expected + + +def assert_vectorize_access(M: int = 64, N: int = 64): + func, expected = vectorize_access_legalize(M, N) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) + + +def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): + dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N), dtype=dtype) + tid = T.get_thread_binding() + for j in T.serial(N): + A_shared[tid, j] = A[tid + M_offset, j + N_offset] + T.atomic_add(A[tid + M_offset, j + N_offset], 1) + + @T.prim_func + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N), dtype=dtype) + tid = T.get_thread_binding() + + T.reads(A[tid + M_offset, N_offset : N + N_offset]) + for j in T.serial(N): + A_shared[tid, j] = T.if_then_else( + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) + # 
Nest if-then-else is expected, do not flatten it to pass structural equal check + if j + N_offset < N: # noqa: SIM102 + if tid + M_offset < M: + T.call_extern("handle", "AtomicAdd", T.address_of(A[tid + M_offset, j + N_offset]), 1) + + return main, expected + + +def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64): + func, expected = vectorize_access_with_atmoic_add_legalize(M, N) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) + + +def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): + dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + tid = T.get_thread_binding() + for j in T.serial(N): + A[tid + M_offset, j + N_offset] = 1 + + @T.prim_func + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + tid = T.get_thread_binding() + T.writes(A[tid + M_offset, N_offset : N + N_offset]) + for j in T.serial(N): + if j + N_offset < N: # noqa: SIM102 + if tid + M_offset < M: + A[tid + M_offset, j + N_offset] = T.float32(1.0) + + return main, expected + + +def assert_oob_store_legalize(M: int = 64, N: int = 64): + func, expected = oob_store_legalize(M, N) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) + + +def test_vectorize_access(): + assert_vectorize_access(64, 64) + + +def test_vectorize_access_with_atmoic_add(): + assert_vectorize_access_with_atmoic_add(64, 64) + + +def test_oob_store(): + assert_oob_store_legalize(64, 64) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc7541ccfb64c33256f8f4a11b9702474c28b33 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py @@ -0,0 +1,49 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def vectorize_access_legalize(M: int = 64, N: int = 64): + dtype = T.float32 + vec_len = 8 + + @T.prim_func + def main( + A: T.Tensor((M, N, vec_len), dtype=T.float32), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) + tid = T.get_thread_binding() + for j in T.serial(N): + for v in T.vectorized(vec_len): + A_shared[tid, j, v] = A[tid, j, v] + + @T.prim_func + def expected( + A: T.Tensor((M, N, vec_len), dtype=T.float32), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) + tid = T.get_thread_binding() + for j, v_2 in T.grid(M, vec_len // 4): + for vec in T.vectorized(4): + A_shared[tid, j, v_2 * 4 + vec] = A[tid, j, v_2 * 4 + vec] + + return main, expected + + +def assert_vectorize_access(M: int = 64, N: int = 64): + func, expected = vectorize_access_legalize(M, N) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + with tvm.target.Target("cuda"): + transformed = tl.transform.LegalizeVectorizedLoop()(mod) + tvm.ir.assert_structural_equal(transformed["main"].body, 
expected.body) + + +def test_vectorize_access(): + assert_vectorize_access(64, 64) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_let_inline.py b/tilelang/original/testing/python/transform/test_tilelang_transform_let_inline.py new file mode 100644 index 0000000000000000000000000000000000000000..e773e3feebe37cd5c2a3bc9c3950b0cb33ca59d4 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_let_inline.py @@ -0,0 +1,52 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.LetInline()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) + + +def test_let_binding(): + @T.prim_func + def before(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)): + for i in range(128): + for j in range(128): + with T.block("compute"): + factor = T.float32(2.0) + value = A[i, j] * factor + B[i, j] = value + + @T.prim_func + def expected(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)): + for i in range(128): + for j in range(128): + with T.block("compute"): + B[i, j] = A[i, j] * T.float32(2.0) + + _check(before, expected) + + +def test_parallel_scope(): + @T.prim_func + def before(A: T.Tensor((128,), T.float32)): + for i in T.Parallel(128): + with T.block("parallel"): + value = T.float32(1.0) + A[i] = value + + @T.prim_func + def expected(A: T.Tensor((128,), T.float32)): + for i in T.Parallel(128): + with T.block("parallel"): + A[i] = T.float32(1.0) + + _check(before, expected) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/tilelang/original/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py new file mode 100644 index 0000000000000000000000000000000000000000..f411b3d5b5779ec5b069deb5c6d9da4d1ba9ee26 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -0,0 +1,49 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing +from tvm import tir + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.LowerHopperIntrin()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tir.transform.LowerOpaqueBlock()(transformed) + transformed["main"] = transformed["main"].with_attr("tma_descriptor_args", {}) + + # TODO: temporary remove this check + # tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +def test_lower_hopper_intrin_barrier(): + @T.prim_func + def before(): + with T.Kernel(8): + _ = T.launch_thread("threadIdx.x", 128) + T.create_list_of_mbarrier(128, 128, 128, 128) + + @T.prim_func + def after(): + with T.Kernel(8): + v_1 = T.launch_thread("threadIdx.x", 128) + T.evaluate(tir.Call("handle", 
"tir.create_barriers", [4])) + with T.If(v_1 == 0), T.Then(): + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128])) + T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"])) + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/tilelang/original/testing/python/transform/test_tilelang_transform_lower_tile_op.py new file mode 100644 index 0000000000000000000000000000000000000000..16c7cb8027853b68df9bcb244e09e6613efe4255 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -0,0 +1,88 @@ +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang as tl +import tilelang.language as T +import tilelang.testing +import pytest + +auto_target = tvm.target.Target(determine_target("auto")) + + +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, T.float16), + ], +) +def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): + N = tvm.te.var("n") + K = tvm.te.var("k") + + def before(): + @T.prim_func + def main( + B: T.Tensor((K, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared((block_K, block_N), dtype) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(B[k * block_K, bx * block_N], B_shared) + + return tvm.IRModule({"main": main}) + + def after(): + @T.prim_func + def main( + B: T.Tensor((K, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared((block_K, block_N), dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + t = thread_bindings + for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: + for vec in T.vectorized(vec_load_b): + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) + else: + for vec in T.serial(vec_load_b): + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) + + return 
tvm.IRModule({"main": main}) + + with tvm.transform.PassContext(): + mod = tvm.tir.transform.BindTarget(auto_target)(before()) + mod = tl.transform.LowerTileOp()(mod) + mod = tvm.tir.transform.Simplify()(mod) + ref_mod = tvm.tir.transform.BindTarget(auto_target)(after()) + ref_mod = tvm.tir.transform.Simplify()(ref_mod) + # Note(tzj): The structures are equal except the argument in "T.reads" function. + # The difference is just between the first index and the indices range, which is totally equivalent + tvm.ir.structural_equal(mod, ref_mod) + # tvm.ir.assert_structural_equal(mod, ref_mod) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_make_packed_api.py b/tilelang/original/testing/python/transform/test_tilelang_transform_make_packed_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2508a9d12e4c844152322246fbd5f72c57acf4f2 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_make_packed_api.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa + +import pytest +import numpy as np +import tilelang +from tilelang import tvm as tvm +import tvm +import tilelang.testing +from tvm import tir +from tvm.script import tir as T, ir as I + + +def _find_assignment(stmt, var_name): + while not isinstance(stmt, tvm.tir.LetStmt): + stmt = stmt.body + + if stmt.var.name != var_name: + return _find_assignment(stmt.body, var_name) + + return stmt + + +def _find_compute_scope(func): + result = None + + def _visitor(stmt): + if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope": + nonlocal result + result = stmt + + tir.stmt_functor.post_order_visit(func.body, _visitor) + + return result + + +@pytest.mark.parametrize("use_global_symbol", [False]) +def test_no_op_when_global_symbol_is_absent(use_global_symbol): + func_attr = {"target": tvm.target.Target("llvm", host="llvm")} + + @T.prim_func(private=True) + def before(): + T.func_attr(func_attr) + T.evaluate(0) + + if use_global_symbol: + before = before.with_attr("global_symbol", "main") + + after = tilelang.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] + if use_global_symbol: + assert len(after.params) == 4 + else: + tvm.ir.assert_structural_equal(before, after) + + +def test_target_host_removed(): + """After MakePackedAPI, host-side target should be the host + + MakePackedAPI is the last transform that requires both the device + and the host. After MakePackedAPI, the target attribute should + only contain the host-side target. 
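+ In the test below the device target is cuda with an llvm host, so the resulting attribute should compare equal to the bare llvm target.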
+ """ + + host = tvm.target.Target("llvm") + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) + T.evaluate(0) + + after = tilelang.transform.MakePackedAPI()(before) + target_attr = after["main"].attrs["target"] + assert str(host) == str(target_attr) + + +def test_internal_subroutine_call(): + """Internal subroutines should not use the PackedFunc API + + A subroutine without the "global_symbol" attribute is an internal + subroutine, and is not directly exposed to a user of the generated + `runtime.Module`. Therefore, it doesn't need to follow the + PackedFunc API. + """ + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm", host="llvm")}) + before.subroutine(A.data) + + # this test fails if it's made public + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(A_data) + + after = tilelang.transform.MakePackedAPI()(before) + tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"]) + + compute_scope = _find_compute_scope(after["main"]) + subroutine_call_op = compute_scope.body.value.op + assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), ( + f"The main function's CallNode should use the subroutine's GLobalVar as the operation, " + f"but instead has an operation of type {subroutine_call_op}" + ) + + +def test_subroutine_call_to_externally_visible_subroutine(): + """Externally-visible subroutines should use the PackedFunc API + + Because the subroutine may be called directly by a user, it must + use the PackedFunc API. Its signature should be updated to the + PackedFunc signature, and call sites should be updated to use + `T.tvm_call_cpacked`. 
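+ The check below asserts that the call in `main` is lowered to the `tir.tvm_call_cpacked` builtin.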
+ """ + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) + before.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}) + T.evaluate(A_data) + + after = tilelang.transform.MakePackedAPI()(before) + + main_compute_scope = _find_compute_scope(after["main"]) + assert main_compute_scope is not None + subroutine_compute_scope = _find_compute_scope(after["subroutine"]) + assert subroutine_compute_scope is not None + + subroutine_call_op = main_compute_scope.body.value.op + assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", ( + f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " + f"but instead has an operation of type {subroutine_call_op}" + ) + + +@tilelang.testing.requires_llvm +def test_function_call_with_wrong_argument_count(): + """Argument counts must be checked before accessing the type codes""" + + @T.prim_func + def func( + A: T.Buffer([16, 16], "int32"), + B: T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + D: T.Buffer([16, 16], "int32"), + ): + pass + + built = tvm.compile(func, target="llvm") + + with pytest.raises(tvm.TVMError): + built() + + +@tilelang.testing.requires_llvm +def test_function_call_with_wrong_type_code(): + """Type codes must be checked before accessing the arguments""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32")): + pass + + built = tvm.compile(func, target="llvm") + + with pytest.raises(tvm.TVMError): + built(0) + + +@tilelang.testing.requires_llvm +def test_function_call_with_null_data_pointer(): + """The data pointer must be checked before accessing the array""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): + for i, j in T.grid(16, 16): + B[i, j] = A[i, j] + + built = tvm.compile(func, target="llvm") + + A = tvm.nd.array(np.zeros([16], dtype="int32")) + B = tvm.nd.empty([16, 16], "int32", tvm.cpu()) + + with pytest.raises(tvm.TVMError): + built(A, B) + + +@tilelang.testing.requires_llvm +def test_function_call_with_wrong_dimensionality(): + """The dimensionality must be checked before validating the shape""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): + for i, j in T.grid(16, 16): + B[i, j] = A[i, j] + + built = tvm.compile(func, target="llvm") + + A = tvm.nd.array(np.zeros([16], dtype="int32")) + B = tvm.nd.empty([16], "int32", tvm.cpu()) + + with pytest.raises(tvm.TVMError): + built(A, B) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/tilelang/original/testing/python/transform/test_tilelang_transform_multi_version_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..e85fd8db8d1eff525b7e6ec8c1a0e9698717d445 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_multi_version_buffer.py @@ -0,0 +1,144 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing +from tvm import tir + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = 
tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.MultiVersionBuffer()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tir.transform.LowerOpaqueBlock()(transformed) + + tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +M = 512 +N = 512 +K = 512 +dtype = T.float16 +block_M = 64 +block_N = 64 +block_K = 32 + + +def test_multi_version_buffer(): + @T.prim_func + def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 128) + with T.block(""): + T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) + T.writes() + A_shared = T.alloc_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(2): + C_local[i * 2 + vec] = T.float32(0) + for k in T.serial(16, annotations={"num_stages": T.int32(3)}): + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2), + k * 32, + by * 64, + ) + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 2), + bx * 64, + k * 32, + ) + T.call_extern( + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + + @T.prim_func + def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 128) + with T.block(""): + T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) + T.writes() + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(2): + C_local[i * 2 + vec] = T.float32(0) + for k in T.serial(16, annotations={"num_stages": T.int32(3)}): + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) + T.call_extern( + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 
3), + ) + + _check(before, after) + + +def test_multi_version_buffer_with_let(): + @T.prim_func + def before(scales: T.Tensor((4,), T.float32)): + with T.block("root"): + shared = T.alloc_buffer((8,), T.float32, scope="shared.dyn") + accum = T.alloc_buffer((8,), T.float32, scope="local") + for k in T.serial(4, annotations={"num_stages": T.int32(2)}): + value = scales[k] + for i in T.serial(8): + shared[i] = value + for i in T.serial(8): + accum[i] = accum[i] + shared[i] + + @T.prim_func + def after(scales: T.Tensor((4,), T.float32)): + with T.block("root"): + shared = T.alloc_buffer((2, 8), T.float32, scope="shared.dyn") + accum = T.alloc_buffer((8,), T.float32, scope="local") + for k in T.serial(4, annotations={"num_stages": T.int32(2)}): + value = scales[k] + for i in T.serial(8): + shared[k % 2, i] = value + for i in T.serial(8): + accum[i] = accum[i] + shared[k % 2, i] + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/tilelang/original/testing/python/transform/test_tilelang_transform_pipeline_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..83db7f75cf34e242aca124263f728de8e3b2944f --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -0,0 +1,66 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.PipelinePlanning()(mod) + mod = tl.transform.Simplify()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +def test_simple_pipeline(): + @T.prim_func + def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): + with T.Kernel(8, 8, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) + + T.clear(C_local) + + for ko in T.Pipelined(32, num_stages=3): + T.copy(A[by * 128, ko * 32], A_shared) + T.copy(B[ko * 32, bx * 128], B_shared) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * 128, bx * 128]) + + @T.prim_func + def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): + with T.Kernel(8, 8, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) + + T.clear(C_local) + + for ko in T.serial( + 32, + annotations={ + "software_pipeline_async_stages": [T.int32(0)], + "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)], + "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)], + }, + ): + T.copy(A[by * 128, ko * 32], A_shared) + T.copy(B[ko * 32, bx * 128], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * 128, bx * 128]) + + _check(before, after) + + +if __name__ == "__main__": + 
tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_simplify.py b/tilelang/original/testing/python/transform/test_tilelang_transform_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..3b737682040ab78b91bd1ee40f38b59e60cb4880 --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_simplify.py @@ -0,0 +1,91 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def modify( + with_B: bool = False, + with_bias: bool = False, +): + @T.prim_func + def main( + A: T.Tensor((64, 64)), + B: T.Tensor((64, 64)), + C: T.Tensor((64, 64)), + D: T.Tensor((64, 64)), + bias: T.Tensor((64, 64)), + ): + if with_B: + if with_bias: + T.gemm(A, bias, D) + T.gemm(A, B, D) + else: + with T.block(): + A_shared = T.alloc_shared((64, 64), dtype=T.float32) + C_shared = T.alloc_shared((64, 64), dtype=T.float32) + D_shared = T.alloc_shared((64, 64), dtype=T.float32) + T.copy(A, A_shared) + T.copy(C, C_shared) + T.gemm(A_shared, C_shared, D_shared) + T.copy(D_shared, D) + + return main + + +def test_modify(with_B=False, with_bias=False): + tester = modify(with_B=with_B, with_bias=with_bias) + mod = tvm.IRModule({tester.attrs["global_symbol"]: tester}) + mod2 = tl.transform.Simplify()(mod) + assert mod != mod2 + + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + a: T.handle, + b: T.handle, + c: T.handle, + ): + A = T.match_buffer(a, (M, K), dtype=dtype) + B = T.match_buffer(b, (K, N), dtype=dtype) + C = T.match_buffer(c, (M, N), dtype=accum_dtype) + + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def test_matmul(): + func = matmul(1024, 1024, 1024, 128, 128, 32) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + mod = tl.transform.Simplify()(mod) + kernel = tl.compile(mod["main"], out_idx=[2]) + + import torch + + a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() + b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() + c = kernel(a, b) + + ref_c = a @ b + ref_c = ref_c.float() + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + # Get CUDA Source + print(kernel.get_kernel_source()) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_thread_sync.py b/tilelang/original/testing/python/transform/test_tilelang_transform_thread_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..046ed447a49d0250ff29faac7ecd13c99e6be0ce --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -0,0 +1,224 @@ +# ruff: noqa + +from tilelang import tvm as tvm +import tilelang.testing +from tvm.script import tir as T +from tvm import te + + +def run_passes(func: tvm.tir.PrimFunc): + mod = tvm.IRModule.from_expr(func) + + cuda_target = tvm.target.Target("cuda", host="llvm") + + mod = tvm.tir.transform.Apply(lambda f: 
f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod) + + mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) + mod = tvm.tir.transform.SplitHostDevice()(mod) + return tilelang.transform.ThreadSync("shared")(mod) + + +@tilelang.testing.requires_cuda +def test_sync_if_with_same_index(): + @T.prim_func(check_well_formed=False) + def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + p0 = T.Buffer([2], dtype="float32", data=p0_arg.data) + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") + T.launch_thread(blockIdx_x, 8) + T.launch_thread(threadIdx_x, 4) + result_local[0] = T.float32(0) + if threadIdx_y < 8: + temp_shared[threadIdx_x] = p0[0] + temp_shared[threadIdx_x] = temp_shared[threadIdx_x] + result_local[0] = result_local[0] + temp_shared[0] + + mod = run_passes(func) + assert "T.tvm_storage_sync" in str(mod) + + +@tilelang.testing.requires_cuda +def test_sync_read_thread_id_independent_location(): + @T.prim_func + def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + p0 = T.Buffer([2], dtype="float32", data=p0_arg.data) + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") + T.launch_thread(blockIdx_x, 8) + T.launch_thread(threadIdx_x, 4) + result_local[0] = T.float32(0) + if threadIdx_x < 1: + temp_shared[0] = p0[0] + result_local[0] = result_local[0] + temp_shared[0] * p1[0] + if threadIdx_x < 1: + temp_shared[0] = p0[1] + result_local[0] = result_local[0] + temp_shared[0] * p1[1] + + mod = run_passes(func) + assert "T.tvm_storage_sync" in str(mod) + + +@tilelang.testing.requires_cuda +def test_sync_shared(): + @T.prim_func(private=True) + def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + B = T.allocate([24], "float32", "shared") + C = T.allocate([1], "float32", "local") + D = T.allocate([16], "float32", "shared") + threadIdx_x = T.launch_thread("threadIdx.x", 16) + B_1 = T.Buffer((24,), data=B, scope="shared") + A_1 = T.Buffer((16,), data=A.data) + B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] + C_1 = T.Buffer((1,), data=C, scope="local") + C_1[0] = B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] + D_1 = T.Buffer((16,), data=D, scope="shared") + D_1[threadIdx_x] = C_1[0] + E_1 = T.Buffer((16,), data=E.data) + E_1[threadIdx_x] = D_1[threadIdx_x] + + @T.prim_func(private=True) + def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + B_1 = T.allocate([24], "float32", "shared") + C_1 = T.allocate([1], "float32", "local") + D_1 = T.allocate([16], "float32", "shared") + threadIdx_x = T.launch_thread("threadIdx.x", 16) + B_1_1 = T.Buffer((24,), data=B_1, scope="shared") + A_1 = T.Buffer((16,), data=A.data) + B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] + C_1_1 = T.Buffer((1,), data=C_1, scope="local") + C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] + D_1_1 = T.Buffer((16,), data=D_1, scope="shared") + D_1_1[threadIdx_x] = C_1_1[0] + E_1 = T.Buffer((16,), data=E.data) + E_1[threadIdx_x] = D_1_1[threadIdx_x] + + mod = tvm.IRModule({"main": 
func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +@tvm.testing.requires_cuda +def test_sync_let_stmt(): + @T.prim_func(private=True) + def func(A: T.Buffer((16 * 512), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 16) + A_shared = T.allocate([512], "float32", "shared") + in_thread_A_temp = T.allocate([1], "float32", "local") + cross_thread_A_temp = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared") + for ax0 in range(512): + A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] + in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local") + in_thread_A_temp_1[0] = T.float32(0) + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: + in_thread_A_temp_1[0] = A_temp + cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + in_thread_A_temp_1[0], + T.bool(True), + cross_thread_A_temp_1[0], + threadIdx_x, + ) + + @T.prim_func(private=True) + def expected(A: T.Buffer((8192,), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 16) + A_shared_1 = T.allocate([512], "float32", "shared") + in_thread_A_temp_1 = T.allocate([1], "float32", "local") + cross_thread_A_temp_1 = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared") + for ax0 in range(512): + A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] + in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1, scope="local") + in_thread_A_temp_1_1[0] = T.float32(0) + T.tvm_storage_sync("shared") + with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp: + in_thread_A_temp_1_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp: + in_thread_A_temp_1_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp: + in_thread_A_temp_1_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp: + in_thread_A_temp_1_1[0] = A_temp + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + cross_thread_A_temp_1_1 = T.Buffer((1,), data=cross_thread_A_temp_1, scope="local") + T.tvm_thread_allreduce( + T.uint32(1), + in_thread_A_temp_1_1[0], + T.bool(True), + cross_thread_A_temp_1_1[0], + threadIdx_x, + ) + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +@tilelang.testing.requires_cuda +def test_sync_shared_dyn_stmatrix_loop_hoist(): + @T.prim_func + def func(): + buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn") + tx = T.launch_thread("threadIdx.x", 384) + for i in T.unroll(8): + off = ( + i // 4 * 8192 + + tx // 32 * 1024 + + tx % 16 * 64 + + (tx % 8 // 4 + i % 4 // 2) % 2 
* 32 + + (tx % 4 // 2 + i % 2) % 2 * 16 + + (tx % 32 // 16 + tx % 2) % 2 * 8 + ) + T.evaluate( + T.call_intrin( + "handle", + tvm.tir.op.Op.get("tl.ptx_stmatrix"), + T.int32(0), + T.int32(4), + T.tvm_access_ptr( + T.type_annotation("uint8"), + buf_dyn_shmem.data, + off, + 98304 - off, + 2, + ), + T.int32(2), + ) + ) + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared.dyn")(mod) + s = str(mod) + assert 'T.tvm_storage_sync("shared.dyn")' in s + # Ensure the sync appears before the unrolled loop + assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/transform/test_tilelang_transform_warp_specialized.py b/tilelang/original/testing/python/transform/test_tilelang_transform_warp_specialized.py new file mode 100644 index 0000000000000000000000000000000000000000..0171fab82ce0dba7be8aa7ba342d213018d7d40b --- /dev/null +++ b/tilelang/original/testing/python/transform/test_tilelang_transform_warp_specialized.py @@ -0,0 +1,123 @@ +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing +from tvm import tir + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.WarpSpecialized()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tir.transform.LowerOpaqueBlock()(transformed) + + # TODO: fix loop_var equal bug + # tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +M = 512 +N = 512 +K = 512 +dtype = T.float16 +block_M = 64 +block_N = 64 +block_K = 32 + + +def test_warp_specialized(): + @T.prim_func + def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 128) + with T.block(""): + T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) + T.writes() + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + for k in T.serial(16, annotations={"num_stages": T.int32(3)}): + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + if v == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) + T.call_extern( + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + + @T.prim_func + def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): + bx = T.launch_thread("blockIdx.x", 8) + by = 
T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 256) + A_shared = T.decl_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") + C_local = T.decl_buffer((32,), scope="local") + T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) + T.attr([128, 128], "kWarpSpecializationScope", 0) + if v >= 128: + T.set_max_nreg(24, 0) + for k in range(16): + T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1)) + if v - 128 == 0: + T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) + if v - 128 == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + if v - 128 == 0: + T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) + if v - 128 == 0: + T.tma_load( + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) + else: + T.set_max_nreg(240, 1) + for k in range(16): + T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) + T.call_extern( + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/testing/python/utils/test_compress_utils.py b/tilelang/original/testing/python/utils/test_compress_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fc20539eb9bf89a90d7645d9d5d31f6e3f6428 --- /dev/null +++ b/tilelang/original/testing/python/utils/test_compress_utils.py @@ -0,0 +1,39 @@ +import torch +import tilelang +import tilelang.testing + +from tilelang.utils.sparse import compress_sm90, randn_semi_sparse + + +def _test_compress_sm90(M, K, block_k, dtype): + A = randn_semi_sparse(M, K, dtype=dtype, device="cuda") + A_sparse, E = compress_sm90(A, block_k, False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_compress_sm90(): + _test_compress_sm90(1024, 1024, 128, torch.float16) + _test_compress_sm90(1024, 1024, 64, torch.float16) + _test_compress_sm90(1024, 1024, 32, torch.float16) + + _test_compress_sm90(1024, 1024, 128, torch.bfloat16) + _test_compress_sm90(1024, 1024, 64, torch.bfloat16) + _test_compress_sm90(1024, 1024, 32, torch.bfloat16) + + _test_compress_sm90(1024, 1024, 64, torch.float32) + _test_compress_sm90(1024, 1024, 32, torch.float32) + _test_compress_sm90(1024, 1024, 16, torch.float32) + + _test_compress_sm90(1024, 1024, 256, torch.float8_e4m3fn) + _test_compress_sm90(1024, 1024, 128, torch.float8_e4m3fn) + _test_compress_sm90(1024, 1024, 64, torch.float8_e4m3fn) + + _test_compress_sm90(1024, 1024, 256, torch.float8_e5m2) + _test_compress_sm90(1024, 1024, 128, torch.float8_e5m2) + _test_compress_sm90(1024, 1024, 64, torch.float8_e5m2) + + +if __name__ == "__main__": + test_compress_sm90() + print("All tests passed.") 
diff --git a/tilelang/original/testing/python/webgpu/test_webgpu_codegen.py b/tilelang/original/testing/python/webgpu/test_webgpu_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b199e79d036eea04212ad523c02ed2e7cbcf2c --- /dev/null +++ b/tilelang/original/testing/python/webgpu/test_webgpu_codegen.py @@ -0,0 +1,59 @@ +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2) + T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2) + + for i, j, k in T.grid(block_M, block_N, block_K): + C_local[i, j] += A_shared[i, k] * B_shared[k, j] + + T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2) + + return main + + +def assert_gemm_codegen( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float16, + accum_dtype=T.float32, +): + func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + # Because the current pass context have been polluted by previous testing. + with tvm.transform.PassContext(), tvm.target.Target("webgpu"): + artifact = tilelang.lower(func, target="webgpu") + + src_code = artifact.kernel_source + + assert src_code is not None + + +def test_gemm_codegen(): + assert_gemm_codegen(1024, 1024, 1024, 16, 16, 16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/original/tilelang/__init__.py b/tilelang/original/tilelang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87176b2091b930408bcb81ffd922004919564e24 --- /dev/null +++ b/tilelang/original/tilelang/__init__.py @@ -0,0 +1,154 @@ +import sys +import os +import ctypes + +import logging +import warnings +from pathlib import Path +from tqdm.auto import tqdm + + +def _compute_version() -> str: + """Return the package version without being polluted by unrelated installs. + + Preference order: + 1) If running from a source checkout (VERSION file present at repo root), + use the dynamic version from version_provider (falls back to plain VERSION). + 2) Otherwise, use importlib.metadata for the installed distribution. + 3) As a last resort, return a dev sentinel. + """ + try: + repo_root = Path(__file__).resolve().parent.parent + version_file = repo_root / "VERSION" + if version_file.is_file(): + try: + from version_provider import dynamic_metadata # type: ignore + + return dynamic_metadata("version") + except Exception: + # Fall back to the raw VERSION file if provider isn't available. + return version_file.read_text().strip() + except Exception: + # If any of the above fails, fall through to installed metadata. 
+ pass + + try: + from importlib.metadata import version as _dist_version # py3.8+ + + return _dist_version("tilelang") + except Exception as exc: + warnings.warn( + f"tilelang version metadata unavailable ({exc!r}); using development version.", + RuntimeWarning, + stacklevel=2, + ) + return "0.0.dev0" + + +__version__ = _compute_version() + + +class TqdmLoggingHandler(logging.Handler): + """Custom logging handler that directs log output to tqdm progress bar to avoid interference.""" + + def __init__(self, level=logging.NOTSET): + """Initialize the handler with an optional log level.""" + super().__init__(level) + + def emit(self, record): + """Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted.""" + try: + msg = self.format(record) + tqdm.write(msg) + except Exception: + self.handleError(record) + + +def set_log_level(level): + """Set the logging level for the module's logger. + + Args: + level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). + OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' + """ + if isinstance(level, str): + level = getattr(logging, level.upper(), logging.INFO) + logger = logging.getLogger(__name__) + logger.setLevel(level) + + +def _init_logger(): + """Initialize the logger specific for this module with custom settings and a Tqdm-based handler.""" + logger = logging.getLogger(__name__) + handler = TqdmLoggingHandler() + formatter = logging.Formatter( + fmt="%(asctime)s [TileLang:%(name)s:%(levelname)s]: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.propagate = False + set_log_level("INFO") + + +_init_logger() + +logger = logging.getLogger(__name__) + +from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401 +from .env import env as env # noqa: F401 + +import tvm +import tvm.base # noqa: F401 +from tvm import DataType # noqa: F401 + +# Setup tvm search path before importing tvm +from . import libinfo + + +def _load_tile_lang_lib(): + """Load Tile Lang lib""" + if sys.platform.startswith("win32") and sys.version_info >= (3, 8): + for path in libinfo.get_dll_directories(): + os.add_dll_directory(path) + # pylint: disable=protected-access + lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module" + # pylint: enable=protected-access + lib_path = libinfo.find_lib_path(lib_name) + return ctypes.CDLL(lib_path), lib_path + + +# only load once here +if env.SKIP_LOADING_TILELANG_SO == "0": + _LIB, _LIB_PATH = _load_tile_lang_lib() + +from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401 +from .profiler import Profiler # noqa: F401 +from .cache import clear_cache # noqa: F401 + +from .utils import ( + TensorSupplyType, # noqa: F401 + deprecated, # noqa: F401 +) +from .layout import ( + Layout, # noqa: F401 + Fragment, # noqa: F401 +) +from . import ( + analysis, # noqa: F401 + transform, # noqa: F401 + language, # noqa: F401 + engine, # noqa: F401 + tools, # noqa: F401 +) +from .language.v2 import dtypes # noqa: F401 +from .autotuner import autotune # noqa: F401 +from .transform import PassConfigKey # noqa: F401 + +from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401 + +from .math import * # noqa: F403 + +from . import ir # noqa: F401 + +from . 
import tileop # noqa: F401 diff --git a/tilelang/original/tilelang/_ffi_api.py b/tilelang/original/tilelang/_ffi_api.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6421bf774f375e437515fe1878a1468c2c3e65 --- /dev/null +++ b/tilelang/original/tilelang/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for tilelang""" + +import tvm_ffi + +# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); +tvm_ffi.init_ffi_api("tl", __name__) diff --git a/tilelang/original/tilelang/analysis/__init__.py b/tilelang/original/tilelang/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4090d80e7c6405a4cc7c27241430706ea96bd1 --- /dev/null +++ b/tilelang/original/tilelang/analysis/__init__.py @@ -0,0 +1,6 @@ +"""Tilelang IR analysis & visitors.""" + +from .ast_printer import ASTPrinter # noqa: F401 +from .nested_loop_checker import NestedLoopChecker # noqa: F401 +from .fragment_loop_checker import FragmentLoopChecker # noqa: F401 +from .layout_visual import LayoutVisual # noqa: F401 diff --git a/tilelang/original/tilelang/analysis/ast_printer.py b/tilelang/original/tilelang/analysis/ast_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..e634e02713e4a277940aa487dab2a10d45151dd0 --- /dev/null +++ b/tilelang/original/tilelang/analysis/ast_printer.py @@ -0,0 +1,23 @@ +from tvm import tir +from tvm.tir import PrimFunc +from tvm.tir.transform import prim_func_pass +from tvm.tir.stmt_functor import ir_transform + + +def ASTPrinter(): + """ + Print the AST of a given tilelang module for debugging. + """ + + def pre_visit(statement: tir.Stmt) -> None: + """ + Pre-order visitor to print all visited statements. + """ + + print(f"Visiting statement: {type(statement)}, {statement}") + + def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: + new_body = ir_transform(func.body, pre_visit, None) + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/original/tilelang/analysis/fragment_loop_checker.py b/tilelang/original/tilelang/analysis/fragment_loop_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..94900a5cc6c9cc167e539a47c6bf4d3e70b5db21 --- /dev/null +++ b/tilelang/original/tilelang/analysis/fragment_loop_checker.py @@ -0,0 +1,100 @@ +from __future__ import annotations +from tvm import tir +from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm +from tvm.tir.transform import prim_func_pass +from tvm.tir.stmt_functor import post_order_visit + + +@tir.functor.visitor +class _LoopVarUseAnalyzer(PyStmtExprVisitor): + """Analyze whether a loop variable is used in the given expr.""" + + def __init__(self, var: Var) -> None: + super().__init__() + self.var = var + self.used = False + + def visit_var_(self, op: Var) -> None: + if op == self.var: + self.used = True + # Don't recursively visit children to avoid infinite recursion + + +def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: + """ + Collect local buffer accesses in the loop body. + + Args: + statement: The TIR statement to analyze + + Returns: + Tuple of buffer accesses in the loop body. 
+ """ + + buffer_accesses = [] + + def visit_buffer_access(node): + if isinstance(node, (BufferLoad, BufferStore)) and node.buffer.scope().startswith("local"): + buffer_accesses.append(node) + + post_order_visit(statement, visit_buffer_access) + + return buffer_accesses + + +@tir.functor.visitor +class _FragmentLoopCheckVisitor(PyStmtExprVisitor): + def __init__(self) -> None: + super().__init__() + + def visit_for_(self, op: For) -> None: + if op.kind == tir.ForKind.PARALLEL: + # Fuse consecutive parallel loops + # Other nested cases are all invalid in TileLang. + loops = [op] + child = op.body + while isinstance(child, For) and child.kind == tir.ForKind.PARALLEL: + loops.append(child) + child = child.body + + loops_with_symbolic_ranges = [] + for loop in loops: + if not (isinstance(loop.min, IntImm) and isinstance(loop.extent, IntImm)): + loops_with_symbolic_ranges.append(loop) + + if len(loops_with_symbolic_ranges) > 0: + buffer_accesses = collect_local_buffer_accesses(child) + for loop in loops_with_symbolic_ranges: + for buffer_access in buffer_accesses: + indices = buffer_access.indices + analyzer = _LoopVarUseAnalyzer(loop.loop_var) + for index in indices: + analyzer.visit_expr(index) + if analyzer.used: + raise ValueError( + "[Tilelang Semantic Check] " + f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index " + "a local/fragment buffer, which is not allowed in Tilelang." + ) + + return + + self.visit_stmt(op.body) + + +def FragmentLoopChecker(): + """ + When using T.Parallel over a local/fragment buffer, there are several restrictions: + to ensure that the parallelization is valid. + + 1. The range of loop can not be symbolic. + + Returns: + A prim_func_pass that applies the transformation + """ + + def pass_fn(func: PrimFunc, mod, ctx): + _FragmentLoopCheckVisitor().visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/original/tilelang/analysis/layout_visual.py b/tilelang/original/tilelang/analysis/layout_visual.py new file mode 100644 index 0000000000000000000000000000000000000000..141fb808c49f978fb2fe3089e5903f814bda30fc --- /dev/null +++ b/tilelang/original/tilelang/analysis/layout_visual.py @@ -0,0 +1,86 @@ +import tilelang.language as T +from tvm import tir +from tvm.tir import PyStmtExprVisitor + +from tvm.tir.transform import prim_func_pass +from tilelang.tools.plot_layout import plot_layout + + +def print_fragment_format(layout: T.Fragment) -> str: + """ + Format fragment layout information into a human-readable string. + + Parameters + ---------- + layout : T.Fragment + The fragment layout to format + + Returns + ------- + str + Formatted string showing shape, thread mapping, and index mapping + """ + if isinstance(layout, T.Fragment): + input_shape = layout.get_input_shape() + output_shape = layout.get_output_shape() + lines = [f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}"] + print("\n".join(lines)) + else: + raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") + + +@tir.functor.visitor +class _LayoutVisualVisitor(PyStmtExprVisitor): + """ + User-friendly pass which visualizes fragment layouts inferred during compilation. 
+ + In TileLang, Fragment layouts describe: + - How logical indices (e.g., [i, j]) map to thread IDs + - How logical indices map to register file locations within each thread + - The shape transformation from input dimensions to output dimensions + + This pass generates two types of output: + 1. Textual output: A human-readable description printed to console + 2. Visual diagrams: Color-coded plots saved to files (PDF, PNG, SVG formats) + + Configuration: + The pass is controlled by the TL_ENABLE_LAYOUT_VISUALIZATION configuration option. + The configuration accepts string values: + + - Empty string or not set: Pass does nothing (default, disabled) + - "png": Generate PNG format only (recommended for quick inspection) + - "pdf": Generate PDF format only (recommended for documentation) + - "svg": Generate SVG format only (recommended for web/vector graphics) + - "all": Generate all formats (PDF, PNG, SVG) + - "png,svg": Generate multiple formats (comma-separated) + """ + + def __init__(self, formats: list[str] = ""): + super().__init__() + self.layout_found = [] + self.processed_layouts = set() + self.formats_list = [f for f in formats if f != "txt"] + + def visit_block_(self, op: tir.Block) -> None: + if "layout_map" in op.annotations: + layout_map = op.annotations["layout_map"] + + for key, layout in layout_map.items(): + if isinstance(layout, T.Fragment): + layout_id = str(layout) + if layout_id not in self.processed_layouts: + print(f"{key} inferenced layout:") + print_fragment_format(layout) + for fmt in self.formats_list: + plot_layout(layout, name=f"{key}_layout", formats=fmt) + self.processed_layouts.add(layout_id) + + # super().visit_block_(op) + + +def LayoutVisual(formats: str = ""): + def pass_fn(func: tir.PrimFunc, mod, ctx): + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/original/tilelang/analysis/nested_loop_checker.py b/tilelang/original/tilelang/analysis/nested_loop_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..51da7f4c8ea31c8957e79c366ab7981bd67102af --- /dev/null +++ b/tilelang/original/tilelang/analysis/nested_loop_checker.py @@ -0,0 +1,119 @@ +from tvm import tir +from tvm.tir import ( + For, + Call, + PrimFunc, + PyStmtExprVisitor, +) +from tvm.tir.transform import prim_func_pass + + +def is_pipelined_for(op: For) -> bool: + """Check if a for loop is pipelined.""" + + anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"] + return any(key in op.annotations for key in anno_keys) + + +def is_tile_op(op: Call) -> bool: + """Check if a call is a tile-op""" + + return op.op.get_attr("TLOpBuilder") is not None + + +@tir.functor.visitor +class _NestedLoopCheckVisitor(PyStmtExprVisitor): + def __init__(self) -> None: + super().__init__() + self.in_parallel_context = False + + def visit_for_(self, op: For) -> None: + if op.kind == tir.ForKind.PARALLEL: + child = op.body + + # Special case: continuous nested parallel loop is allowed. + if isinstance(child, tir.For) and child.kind == tir.ForKind.PARALLEL: + self.visit_stmt(child) + return + + # Otherwise + if self.in_parallel_context: + raise ValueError("[Tilelang Semantic Check] Nested parallel loops are not allowed. 
Please check your loop structure.") + self.in_parallel_context = True + super().visit_for_(op) + self.in_parallel_context = False + return + elif is_pipelined_for(op): + if self.in_parallel_context: + raise ValueError( + "[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure." + ) + + super().visit_for_(op) + + def visit_call_(self, op: Call) -> None: + if self.in_parallel_context and is_tile_op(op): + raise ValueError( + f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".' + ) + + +def NestedLoopChecker(): + """ + User-friendly pass which identifies any invalid any nested-loop pattern. + + Nested loops is an annoying problem in tilelang or other polyhedral-style compilers. + It contains many corner cases and undefined behaviours. + + In tilelang, there are four loops: + T.serial + T.Parallel (T.vectorized) + T.Pipelined + T.Persistent + + T.Persistent is a new feature which we do not consider here. + + We define the following rules: + - (Rule 1) T.serial can be nested inside any other loop type without restriction. + - (Rule 2) Consecutive T.Parallel nested loops are not allowed. Including any TileOp (T.copy, etc.) which has + "parallel" behaviours is also forbidden. + + Examples: + for i in T.Parallel(M): + stmt + for j in T.Parallel(N): + ... + + for i in T.Parallel(M): + T.copy(A, B) # forbidden! + + **Only a special case is allowed: strict continuous Parallel loops.** Since we can fuse them into a single T.Parallel loop. + Example: + + for i in T.Parallel(M): + for j in T.Parallel(N): + ... # allowed + - (Rule 3) T.Pipelined inside a T.Parallel is forbidden. + + Examples: + for i in T.Parallel(M): + for j in T.Pipelined(K): # forbidden! + ... + + for i in T.Pipelined(K): + for j in T.Parallel(N): # allowed, ok + ... + + In summary, the problem mainly lies in the "T.Parallel". We highly recommend to use + T.Parallel to implement a tiled operator inside a kernel (e.g. T.gemm level) instead of other usages. + This guideline can help you avoid most of the issues. + + Returns: + A prim_func_pass that applies the transformation + """ + + def pass_fn(func: PrimFunc, mod, ctx): + _NestedLoopCheckVisitor().visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/original/tilelang/autotuner/__init__.py b/tilelang/original/tilelang/autotuner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac292c769f3fc728138c919c1b2d79b3cfdda030 --- /dev/null +++ b/tilelang/original/tilelang/autotuner/__init__.py @@ -0,0 +1,8 @@ +from .tuner import ( + autotune, # noqa: F401 + AutoTuner, # noqa: F401 +) +from .capture import ( + set_autotune_inputs, # noqa: F401 + get_autotune_inputs, # noqa: F401 +) diff --git a/tilelang/original/tilelang/autotuner/capture.py b/tilelang/original/tilelang/autotuner/capture.py new file mode 100644 index 0000000000000000000000000000000000000000..428a6da9047bea8d517dbc16dd5e403bc5c29d48 --- /dev/null +++ b/tilelang/original/tilelang/autotuner/capture.py @@ -0,0 +1,126 @@ +from __future__ import annotations +import threading +from typing import Any + +# Use thread local to store the stack +# This is to avoid the cross-thread interference +_local = threading.local() + + +class CaptureStack: + """ + A simple stack implementation for capturing items in a thread-local context. + Used to manage a stack of items (e.g., input tensors) for auto-tuning capture. 
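+
+    Example (illustrative sketch; "capture" stands for any object pushed by a
+    capture context, e.g. an AutotuneInputsCapture):
+
+        stack = CaptureStack()
+        stack.push(capture)            # entering a capture context
+        assert stack.top() is capture and len(stack) == 1
+        stack.pop()                    # leaving the context restores the previous state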
+ """ + + def __init__(self): + # Initialize an empty list to use as the stack + self.stack = [] + + def push(self, item): + """ + Push an item onto the top of the stack. + + Args: + item: The item to be pushed onto the stack. + """ + self.stack.append(item) + + def pop(self): + """ + Pop and return the top item from the stack. + + Returns: + The item at the top of the stack. + + Raises: + IndexError: If the stack is empty. + """ + return self.stack.pop() + + def top(self): + """ + Return the item at the top of the stack without removing it. + + Returns: + The item at the top of the stack. + + Raises: + IndexError: If the stack is empty. + """ + return self.stack[-1] + + def size(self): + """ + Return the number of items in the stack. + + Returns: + int: The size of the stack. + """ + return len(self.stack) + + def __len__(self): + """ + Return the number of items in the stack (len operator support). + + Returns: + int: The size of the stack. + """ + return len(self.stack) + + def __bool__(self): + """ + Return True if the stack is not empty, False otherwise. + + Returns: + bool: Whether the stack contains any items. + """ + return bool(self.stack) + + +def _get_current_stack() -> CaptureStack: + if not hasattr(_local, "capture_stack"): + _local.capture_stack = CaptureStack() + return _local.capture_stack + + +class AutotuneInputsCapture: + __slots__ = "tensors" + + def __init__(self, tensors: list[Any]): + self.tensors = tensors + + def __enter__(self) -> None: + _get_current_stack().push(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + _get_current_stack().pop() + + +def set_autotune_inputs(*args) -> AutotuneInputsCapture: + """Set input tensors for auto-tuning. + + This function creates a context manager for capturing input tensors + during the auto-tuning process. It supports both: + set_autotune_inputs(a, b, c) + set_autotune_inputs([a, b, c]) + + Args: + *args: Either a single list/tuple of tensors, or multiple tensor arguments. + + Returns: + AutotuneInputsCapture: A context manager for auto-tuning inputs. + """ + if len(args) == 1 and isinstance(args[0], (list, tuple)): + tensors = list(args[0]) + else: + tensors = list(args) + return AutotuneInputsCapture(tensors) + + +def get_autotune_inputs() -> list[Any] | None: + """ + Get the current autotune inputs from the stack. 
+ """ + stack = _get_current_stack() + return stack.top().tensors if stack else None diff --git a/tilelang/original/tilelang/autotuner/param.py b/tilelang/original/tilelang/autotuner/param.py new file mode 100644 index 0000000000000000000000000000000000000000..69ad49c79d91127c7676e904fd256f2f8f56f659 --- /dev/null +++ b/tilelang/original/tilelang/autotuner/param.py @@ -0,0 +1,448 @@ +"""The auto-tune parameters.""" + +from __future__ import annotations + +import tilelang +from tilelang import tvm as tvm +from tvm.tir import PrimFunc +from tvm.target import Target +from typing import Callable, Literal, Any +from dataclasses import dataclass +from pathlib import Path + +from tilelang.jit import JITKernel +import cloudpickle +import os +from tilelang.engine.param import KernelParam +from tilelang import logger +import json +import hashlib +import uuid +from tilelang import env +from tvm.runtime import Executable + +BEST_CONFIG_PATH = "best_config.json" +FUNCTION_PATH = "function.pkl" +LATENCY_PATH = "latency.json" + +# Align file names with cache/kernel_cache.py +DEVICE_KERNEL_PATH = "device_kernel.cu" +HOST_KERNEL_PATH = "host_kernel.cu" +EXECUTABLE_PATH = "executable.so" +KERNEL_LIB_PATH = "kernel_lib.so" +KERNEL_CUBIN_PATH = "kernel.cubin" +KERNEL_PY_PATH = "kernel.py" +PARAMS_PATH = "params.pkl" + + +@dataclass(frozen=True) +class CompileArgs: + """Compile arguments for the auto-tuner. Detailed description can be found in `tilelang.jit.compile`. + Attributes: + out_idx: List of output tensor indices. + execution_backend: Execution backend to use for kernel execution (default: "auto"). + target: Compilation target, either as a string or a TVM Target object (default: "auto"). + target_host: Target host for cross-compilation (default: None). + verbose: Whether to enable verbose output (default: False). + pass_configs: Additional keyword arguments to pass to the Compiler PassContext. + Refer to `tilelang.PassConfigKey` for supported options. + """ + + out_idx: list[int] | int | None = None + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto" + target: Literal["auto", "cuda", "hip"] = "auto" + target_host: str | Target = None + verbose: bool = False + pass_configs: dict[str, Any] | None = None + + def compile_program(self, program: PrimFunc): + return tilelang.compile( + program, + out_idx=self.out_idx, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + ) + + def __hash__(self): + data = { + "execution_backend": self.execution_backend, + "target": str(self.target), + "target_host": str(self.target_host) if self.target_host else None, + "verbose": self.verbose, + "pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None, + } + + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") + + +@dataclass(frozen=True) +class ProfileArgs: + """Profile arguments for the auto-tuner. + + Attributes: + warmup: Number of warmup iterations. + rep: Number of repetitions for timing. + timeout: Maximum time per configuration. + supply_type: Type of tensor supply mechanism. + ref_prog: Reference program for correctness validation. + supply_prog: Supply program for input tensors. 
+ out_idx: Union[List[int], int] = -1 + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto + ref_prog: Callable = None + supply_prog: Callable = None + rtol: float = 1e-2 + atol: float = 1e-2 + max_mismatched_ratio: float = 0.01 + skip_check: bool = False + manual_check_prog: Callable = None + cache_input_tensors: bool = True + """ + + warmup: int = 25 + rep: int = 100 + timeout: int = 30 + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto + ref_prog: Callable = None + supply_prog: Callable = None + rtol: float = 1e-2 + atol: float = 1e-2 + max_mismatched_ratio: float = 0.01 + skip_check: bool = False + manual_check_prog: Callable = None + cache_input_tensors: bool = True + + def __hash__(self): + data = { + "warmup": self.warmup, + "rep": self.rep, + "timeout": self.timeout, + "supply_type": str(self.supply_type), + "rtol": self.rtol, + "atol": self.atol, + "max_mismatched_ratio": self.max_mismatched_ratio, + } + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") + + +@dataclass(frozen=True) +class AutotuneResult: + """Results from auto-tuning process. + + Attributes: + latency: Best achieved execution latency. + config: Configuration that produced the best result. + ref_latency: Reference implementation latency. + libcode: Generated library code. + func: Optimized function. + kernel: Compiled kernel function. + """ + + latency: float | None = None + config: dict | None = None + ref_latency: float | None = None + libcode: str | None = None + func: Callable | None = None + kernel: Callable | None = None + + @staticmethod + def _load_binary(path: str): + with open(path, "rb") as file: + binary = file.read() + return binary + + @staticmethod + def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]): + # Random a temporary file within the same FS as the cache directory + tmp_dir = env.TILELANG_TMP_DIR + os.makedirs(tmp_dir, exist_ok=True) + temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}") + with open(temp_path, mode) as temp_file: + operation(temp_file) + # Use atomic POSIX replace, so other processes cannot see a partial write + os.replace(temp_path, path) + + @staticmethod + def _safe_write_executable(executable: Executable, path: str): + tmp_dir = env.TILELANG_TMP_DIR + os.makedirs(tmp_dir, exist_ok=True) + temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so") + executable.export_library(temp_path) + os.replace(temp_path, path) + + def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False): + """ + Persists a compiled kernel to disk cache. + + Args: + cache_path (Path): The root path for the cache files. + kernel (JITKernel): The compiled kernel to be saved. + verbose (bool): Enable verbose log messages. 
+ + Note: + Saves the following files: + - kernel.cu: The compiled kernel source code + - wrapped_kernel.cu: The wrapped kernel source code + - kernel_lib.so: The compiled kernel library + - params.pkl: The serialized kernel parameters + """ + os.makedirs(cache_path, exist_ok=True) # Ensure directory exists + + # Save device kernel source code + try: + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + if verbose: + logger.debug(f"Saving kernel source code to file: {device_kernel_path}") + if kernel.kernel_source is not None: + self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source)) + except Exception as e: + logger.error(f"Error saving kernel source code to disk: {e}") + + # Save host kernel source code (wrapped) + try: + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + if verbose: + logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") + # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel + if kernel.execution_backend == "tvm_ffi": + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source())) + else: + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source())) + except Exception as e: + logger.error(f"Error saving wrapped kernel source code to disk: {e}") + + # Save kernel library (backend-specific) + try: + if kernel.execution_backend == "nvrtc": + kernel_lib_file = KERNEL_CUBIN_PATH + elif kernel.execution_backend == "tvm_ffi": + kernel_lib_file = EXECUTABLE_PATH + else: + kernel_lib_file = KERNEL_LIB_PATH + + kernel_lib_path = os.path.join(cache_path, kernel_lib_file) + + if kernel.execution_backend == "nvrtc": + # Save cubin and python helper file + src_lib_path = kernel.adapter.libpath + kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) + py_src_path = src_lib_path.replace(".cubin", ".py") + if verbose: + logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") + self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path))) + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) + elif kernel.execution_backend == "tvm_ffi": + executable = kernel.adapter.executable + if verbose: + logger.debug(f"Saving kernel executable to file: {kernel_lib_path}") + self._safe_write_executable(executable, kernel_lib_path) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) + + except Exception as e: + logger.error(f"Error saving kernel library to disk: {e}") + + # Save kernel parameters + try: + params_path = os.path.join(cache_path, PARAMS_PATH) + if verbose: + logger.debug(f"Saving kernel parameters to disk: {params_path}") + self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f)) + except Exception as e: + logger.error(f"Error saving kernel parameters to disk: {e}") + + def _load_kernel_from_disk( + self, + cache_path: Path, + target: str | Target = "auto", + target_host: str | Target = None, + out_idx: list[int] | int | None = None, + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", + pass_configs: dict = None, + compile_flags: list[str] | str | None = None, + func: Callable 
= None, + verbose: bool = False, + ) -> JITKernel: + """ + Loads a previously compiled kernel from disk cache. + + Args: + key (str): The hash key identifying the kernel. + target (Union[str, Target]): Compilation target platform. Defaults to "auto". + target_host (Union[str, Target], optional): Host target platform. + out_idx (List[int], optional): Indices specifying which outputs to return. + execution_backend (Literal): Backend type for execution. Defaults to "cython". + pass_configs (dict, optional): Configuration for compiler passes. + func (Callable, optional): The original function. + verbose (bool): Enable verbose log messages. + + Returns: + JITKernel: The loaded kernel if found, None otherwise. + """ + + if not os.path.exists(cache_path): + return None + + # Resolve backend to pick correct file names + if execution_backend == "nvrtc": + kernel_lib_file = KERNEL_CUBIN_PATH + elif execution_backend == "tvm_ffi": + kernel_lib_file = EXECUTABLE_PATH + else: + kernel_lib_file = KERNEL_LIB_PATH + + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + kernel_lib_path = os.path.join(cache_path, kernel_lib_file) + params_path = os.path.join(cache_path, PARAMS_PATH) + + if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): + return None + + device_kernel_source: str | None = None + host_kernel_source: str | None = None + kernel_params: list[KernelParam] | None = None + + # Load optional device kernel source + try: + if verbose: + logger.debug(f"Loading kernel source code from file: {device_kernel_path}") + with open(device_kernel_path) as f: + device_kernel_source = f.read() + except Exception as e: + logger.error(f"Error loading kernel source code from disk: {e}") + + # Load optional host kernel source + try: + if verbose: + logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") + with open(host_kernel_path) as f: + host_kernel_source = f.read() + except Exception as e: + logger.error(f"Error loading host kernel source code from disk: {e}") + + # Load kernel parameters + try: + if verbose: + logger.debug(f"Loading kernel parameters from file: {params_path}") + with open(params_path, "rb") as f: + kernel_params = cloudpickle.load(f) + except Exception as e: + logger.error(f"Error loading kernel parameters from disk: {e}") + + if host_kernel_source and device_kernel_source and kernel_params: + return JITKernel.from_database( + func=func, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + params=kernel_params, + target=target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + else: + return None + + def save_to_disk(self, path: Path, verbose: bool = False): + if not os.path.exists(path): + os.makedirs(path) + + # save best config (atomic) + if verbose: + logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") + self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) + + # save function (atomic) + if verbose: + logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") + self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) + + # save ref latency (atomic) + if verbose: + logger.debug(f"Saving latency to file: {path / LATENCY_PATH}") + self._safe_write_file( + str(path / LATENCY_PATH), + "w", + 
lambda f: json.dump( + { + "latency": self.latency, + "ref_latency": self.ref_latency, + }, + f, + ), + ) + + # save kernel + self._save_kernel_to_disk(path, self.kernel) + + @classmethod + def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult: + if not os.path.exists(path): + return None + + verbose = compile_args.verbose + # Normalize target and resolve execution backend for loading + from tilelang.utils.target import determine_target as _determine_target + from tilelang.jit.execution_backend import resolve_execution_backend + + norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target + requested_backend = compile_args.execution_backend + resolved_backend = resolve_execution_backend(requested_backend, norm_target) + # load best config + if verbose: + logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}") + with open(path / BEST_CONFIG_PATH) as f: + config = json.load(f) + + # load function + if verbose: + logger.debug(f"Loading function from file: {path / FUNCTION_PATH}") + with open(path / FUNCTION_PATH, "rb") as f: + func = cloudpickle.load(f) + + # load latency + if verbose: + logger.debug(f"Loading latency from file: {path / LATENCY_PATH}") + with open(path / LATENCY_PATH) as f: + latency = json.load(f) + latency, ref_latency = latency["latency"], latency["ref_latency"] + + kernel = cls._load_kernel_from_disk( + cls, + path, + norm_target, + compile_args.target_host, + compile_args.out_idx, + resolved_backend, + compile_args.pass_configs, + None, # compile_flags not tracked here + func, + ) + if kernel is None: + return None + kernel.update_tuner_result( + config=config, + latency=latency, + ref_latency=ref_latency, + ) + result = cls( + config=config, + func=func, + kernel=kernel, + libcode=kernel.get_kernel_source(), + latency=latency, + ref_latency=ref_latency, + ) + return result diff --git a/tilelang/original/tilelang/autotuner/tuner.py b/tilelang/original/tilelang/autotuner/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..8d95037395b1af5f6ede1149a952b5a1d7c460ef --- /dev/null +++ b/tilelang/original/tilelang/autotuner/tuner.py @@ -0,0 +1,771 @@ +"""The auto-tune module for tilelang programs. + +This module provides functionality for auto-tuning tilelang programs, including JIT compilation +and performance optimization through configuration search. 
+""" + +from __future__ import annotations +from dataclasses import dataclass + +import tilelang +from tilelang import tvm as tvm +from tilelang.jit import JITImpl +from tilelang.jit.kernel import JITKernel +from tvm.tir import PrimFunc, Var +from tvm.target import Target +import inspect +from functools import partial +from typing import Callable, Generic, Literal, Any, TypeVar + +# Python 3.9 compatibility for ParamSpec +try: + from typing import ParamSpec +except ImportError: # Python < 3.10 + from typing_extensions import ParamSpec +from tqdm.auto import tqdm +import logging +import concurrent.futures +import torch +import os +import sys +import signal +import json +import hashlib +import threading +import traceback +from pathlib import Path + +from tilelang import env +from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult +from tilelang.utils.language import get_prim_func_name +from tilelang.autotuner.capture import get_autotune_inputs +from tilelang.utils.target import determine_target +from tilelang import __version__ + + +class TimeoutException(Exception): + pass + + +def timeout_handler(signum, frame): + raise TimeoutException("Operation timed out") + + +def run_with_timeout(func, timeout, *args, **kwargs): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout) + try: + result = func(*args, **kwargs) + except Exception as e: + raise e + finally: + signal.alarm(0) + return result + + +# Configure logging for the autotuner module +# TODO: Consider creating a common logger in utils +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.propagate = False + +# Lazy handler initialization flag +_logger_handlers_initialized = False + + +def _init_logger_handlers(): + global _logger_handlers_initialized + if _logger_handlers_initialized: + return + formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s") + file_handler = logging.FileHandler("autotuner.log", mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.addHandler(console_handler) + _logger_handlers_initialized = True + + +def get_available_cpu_count() -> int: + """Gets the number of CPU cores available to the current process.""" + try: + cpu_count = len(os.sched_getaffinity(0)) + except AttributeError: + cpu_count = os.cpu_count() + + return cpu_count or 1 + + +class AutoTuner: + """Auto-tuner for tilelang programs. + + This class handles the auto-tuning process by testing different configurations + and finding the optimal parameters for program execution. + + Args: + fn: The function to be auto-tuned. + configs: List of configurations to try during auto-tuning. + """ + + compile_args = CompileArgs() + profile_args = ProfileArgs() + + _kernel_parameters: tuple[str, ...] | None = None + _function_parameters: dict[str, Any] | None = None + _lock = threading.Lock() # For thread safety + _memory_cache = {} # In-memory cache dictionary + cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner" + + def __init__(self, fn: Callable, configs): + self.fn = fn + self.configs = configs + self.ref_latency_cache = None + self.jit_input_tensors = None + self.ref_input_tensors = None + self.jit_compile = None + + @classmethod + def from_kernel(cls, kernel: Callable, configs): + """Create an AutoTuner instance from a kernel function. 
+ + Args: + kernel: The kernel function to auto-tune. + configs: List of configurations to try. + + Returns: + AutoTuner: A new AutoTuner instance. + """ + return cls(kernel, configs) + + def set_compile_args( + self, + out_idx: list[int] | int | None = None, + target: Literal["auto", "cuda", "hip", "metal"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + target_host: str | Target = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + ): + """Set compilation arguments for the auto-tuner. + + Args: + out_idx: List of output tensor indices. + target: Target platform. + execution_backend: Execution backend to use for kernel execution. + target_host: Target host for cross-compilation. + verbose: Whether to enable verbose output. + pass_configs: Additional keyword arguments to pass to the Compiler PassContext. + + Returns: + AutoTuner: Self for method chaining. + """ + # Normalize target to a concrete TVM Target and resolve execution backend + t = Target(determine_target(target)) + from tilelang.jit.execution_backend import resolve_execution_backend + + resolved_backend = resolve_execution_backend(execution_backend, t) + + self.compile_args = CompileArgs( + out_idx=out_idx, + target=t, + execution_backend=resolved_backend, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + ) + + return self + + def set_profile_args( + self, + warmup: int = 25, + rep: int = 100, + timeout: int = 30, + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, + ref_prog: Callable = None, + supply_prog: Callable = None, + rtol: float = 1e-2, + atol: float = 1e-2, + max_mismatched_ratio: float = 0.01, + skip_check: bool = False, + manual_check_prog: Callable = None, + cache_input_tensors: bool = False, + ): + """Set profiling arguments for the auto-tuner. + + Args: + supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided. + ref_prog: Reference program for validation. + supply_prog: Supply program for input tensors. + rtol: Relative tolerance for validation. + atol: Absolute tolerance for validation. + max_mismatched_ratio: Maximum allowed mismatch ratio. + skip_check: Whether to skip validation. + manual_check_prog: Manual check program for validation. + cache_input_tensors: Whether to cache input tensors. + warmup: Number of warmup iterations. + rep: Number of repetitions for timing. + timeout: Maximum time per configuration. + + Returns: + AutoTuner: Self for method chaining. + """ + # If the program is under `with set_autotune_inputs` context, + # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead. + if get_autotune_inputs() is not None: + if supply_prog is not None: + logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.") + supply_prog = lambda _: get_autotune_inputs() # noqa: E731 + + self.profile_args = ProfileArgs( + supply_type=supply_type, + ref_prog=ref_prog, + supply_prog=supply_prog, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + skip_check=skip_check, + manual_check_prog=manual_check_prog, + cache_input_tensors=cache_input_tensors, + warmup=warmup, + rep=rep, + timeout=timeout, + ) + + # If a custom `supply_prog` is provided, the profiler's `supply_type` setting + # becomes ineffective. The custom supply program will be used instead. 
+ if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto: + logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.") + + return self + + def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]): + # for cache key generation + self._kernel_parameters = k_parameters + self._function_parameters = f_parameters + + def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None: + """Generate a cache key for the auto-tuning process.""" + + def _normalize_param(value): + if isinstance(value, Var): + return str(value) + if isinstance(value, (list, tuple)): + return [_normalize_param(v) for v in value] + if isinstance(value, dict): + return {str(k): _normalize_param(v) for k, v in value.items()} + return value + + # extract parameters from the function signature + op_parameters = [] + for _, default_value in parameters.items(): + if default_value.default is not inspect.Parameter.empty: + op_parameters.append(default_value.default) + + if self._kernel_parameters is not None: + op_parameters += _normalize_param(self._kernel_parameters) + + func_source = inspect.getsource(self.fn) + key_data = { + "version": __version__, + "op_parameters": tuple(op_parameters), + "extra_parameters": extra_parameters, + "func_source": func_source, + "configs": self.configs, + "compile_args": hash(self.compile_args), + "profile_args": hash(self.profile_args), + } + # Sort keys to ensure consistency + key_string = json.dumps(key_data, sort_keys=True) + return hashlib.sha256(key_string.encode()).hexdigest() + + def _save_result_to_disk(self, key, result: AutotuneResult): + result.save_to_disk(self.cache_dir / key, self.compile_args.verbose) + + def _load_result_from_disk(self, key) -> AutotuneResult: + result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args) + return result + + def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): + """Run the auto-tuning process. + + Args: + warmup: Number of warmup iterations. + rep: Number of repetitions for timing. + timeout: Maximum time per configuration. + + Returns: + AutotuneResult: Results of the auto-tuning process. + """ + _init_logger_handlers() + + sig = inspect.signature(self.fn) + parameters = sig.parameters + + # NOTE(chaofan): We need to extract some parameters from the closure. + # Consider the case: + # def gemm(M, N, K): + # def kernel(...) + # If we only extract source, M/N/K will be symbolic and there will be cache problem. 
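+ # Record the concrete values captured from the enclosing scope (e.g. M/N/K above) so that
+ # they become part of the cache key: kernels built with different closure values must not
+ # share an autotuning cache entry.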
+ extra_parameters: dict[str, Any] = {} + cells = self.fn.__closure__ + var_names = self.fn.__code__.co_freevars + if cells is not None: + assert len(var_names) == len(cells), "Number of free variables does not match" + for var_name, cell in zip(var_names, cells): + if var_name in parameters: + continue + # Cell content must be serializable + assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), ( + f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}" + ) + extra_parameters[var_name] = cell.cell_contents + + if isinstance(self.configs, Callable): + self.configs = self.configs(*self._kernel_parameters) + + key = self.generate_cache_key(parameters, extra_parameters) + + with self._lock: + if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): + # First check in-memory cache + if key in self._memory_cache: + # Include PrimFunc name when hitting autotuner memory cache + cached_result = self._memory_cache[key] + prim = getattr(cached_result, "func", None) + kernel_name = get_prim_func_name(prim, "") + logger.warning( + "Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.", + kernel_name, + ) + return cached_result + + # Then check disk cache + result = self._load_result_from_disk(key) + if result is not None: + # Populate memory cache with disk result + self._memory_cache[key] = result + return result + + best_latency: float = 1e8 + best_config: dict[str, Any] | None = None + best_kernel: tilelang.JITKernel | None = None + + def _compile(**config_arg) -> tilelang.JITKernel: + compile_args = self.compile_args + return compile_args.compile_program(self.fn(**config_arg)) + + if self.jit_compile is None: + self.jit_compile = _compile + + def target_fn(jit_kernel: tilelang.JITKernel): + # Unpack the context + profile_args = self.profile_args + supply_type = profile_args.supply_type + skip_check = profile_args.skip_check + manual_check_prog = profile_args.manual_check_prog + cache_input_tensors = profile_args.cache_input_tensors + ref_prog = profile_args.ref_prog + supply_prog = profile_args.supply_prog + rtol = profile_args.rtol + atol = profile_args.atol + max_mismatched_ratio = profile_args.max_mismatched_ratio + + profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type) + + # Factory functions for generating input tensors. + # This encapsulates the logic of using either a custom supply program (`supply_prog`) + # or the default profiler input generation (`profiler._get_inputs`). 
+ def get_input_tensors_supply(with_output: bool): + def func(): + if supply_prog is not None: + return supply_prog(profiler._get_params(with_output=with_output)) + else: + return profiler._get_inputs(with_output=with_output) + + return func + + jit_input_tensors_supply = get_input_tensors_supply(with_output=False) + ref_input_tensors_supply = get_input_tensors_supply(with_output=False) + + if cache_input_tensors: + params = profiler._get_params(with_output=False) + if self.jit_input_tensors is None: + self.jit_input_tensors = jit_input_tensors_supply() + else: + # check if the cached tensors are compatible with the current configuration + assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" + for p, c in zip(params, self.jit_input_tensors): + if not isinstance(c, torch.Tensor): + # skip non-tensor inputs checking + continue + + # Check tensor compatibility using generator expression + def shape_equal(a, b): + return all( + a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape) + ) + + if p.dtype != c.dtype or not shape_equal(p, c): + logger.warning( + "\nIncompatible input tensor properties detected between cached tensors and " + "tensors regenerated for the current configuration trial. " + "This can happen if different tuning configurations require different input shapes/dtypes " + "and input tensor caching is enabled.\n" + "To ensure fresh, compatible inputs are generated for every trial " + "you can disable caching by setting:\n" + " `cache_input_tensors=False`\n" + "within your `.set_compile_args(...)` call.\n" + ) + # otherwise, regenerate the input tensors for safety + self.jit_input_tensors = jit_input_tensors_supply() + break + else: + self.jit_input_tensors = jit_input_tensors_supply() + + if (not skip_check) and (ref_prog is not None): + if manual_check_prog is not None: + profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog) + else: + profiler.assert_allclose( + ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio + ) + latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) + + if self.ref_latency_cache is None and ref_prog is not None: + self.ref_input_tensors = ref_input_tensors_supply() + self.ref_latency_cache = profiler.do_bench(ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) + + return latency, self.ref_latency_cache + + config_args = [] + for config in self.configs: + new_kwargs = {} + keys = config.keys() + for name, _ in parameters.items(): + if name in config: + new_kwargs[name] = config[name] + unused_keys = set(keys) - set(new_kwargs.keys()) + if len(unused_keys) > 0: + raise ValueError(f"Unused keys in config: {unused_keys}") + config_args.append(new_kwargs) + + if len(config_args) == 0: + raise ValueError("No configurations to tune, please check your `@autotune` decorator") + + # check if the tunable arguments has been set. 
+ # get the back config argument + top_config, *rest = config_args + + if self._kernel_parameters is not None: + key_args_tuple, key_kwargs_tuple = self._kernel_parameters + tunable_arguments = [key for key, _ in top_config.items()] + + def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool: + params_list = list(parameters.keys()) + assert key in params_list, f"Tunable argument {key} not found in function parameters" + return params_list.index(key) < len(key_args_tuple) + + # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple + if any(key in top_config for key, _ in key_kwargs_tuple) or any( + check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments + ): + logger.warning( + f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" + ) + # compile the kernel with the provided parameters + jit_kernel = self.jit_compile() + autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) + self._memory_cache[key] = autotuner_result + return autotuner_result + # get the cpu count + available_cpu_count = get_available_cpu_count() + cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) + cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS) + max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) + if cpu_counts > 0: + num_workers = min(cpu_counts, available_cpu_count) + logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used") + else: + num_workers = max(1, int(available_cpu_count * cpu_utilizations)) + logger.info( + f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" + ) + + if max_cpu_count > 0 and num_workers > max_cpu_count: + logger.warning( + f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used, but the max CPU count is {max_cpu_count}, so we will use {max_cpu_count} CPUs" + ) + num_workers = max_cpu_count + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) + futures = [] + future_to_index = {} + + def cuda_device_wrapper(func, device): + def inner(**config_arg): + torch.cuda.set_device(device) + return func(**config_arg) + + return inner + + for i, config_arg in enumerate(config_args): + compile_func = self.jit_compile + + if torch.cuda.is_available(): + device = torch.cuda.current_device() + + compile_func = cuda_device_wrapper(self.jit_compile, device) + + future = pool.submit( + compile_func, + **config_arg, + ) + futures.append(future) + future_to_index[future] = i + + results_with_configs = [] + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"): + idx = future_to_index[future] + config = config_args[idx] + try: + result = future.result() + results_with_configs.append((result, config)) + except Exception as e: + logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}") + continue + + ref_latency = None + progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations") + for i in progress_bar: + jit_kernel, config = results_with_configs[i] + try: + # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution + # Because tma init may behave strangely with one thread + # latency, ref_latency 
= target_fn(jit_kernel) + latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) + except TimeoutException: + logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details") + continue + except Exception: + logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details") + logger.debug(f"Error: {traceback.format_exc()}") + continue + + if latency < best_latency: + best_latency = latency + best_config = config + best_kernel = jit_kernel + + progress_bar.set_postfix({"best_latency": best_latency}) + tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}") + + pool.shutdown() + + if best_kernel is None: + error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation." + logger.error(error_msg) + raise RuntimeError(error_msg) + + best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result( + latency=best_latency, + config=best_config, + ref_latency=ref_latency, + ) + + autotuner_result = AutotuneResult( + latency=best_latency, + config=best_config, + ref_latency=ref_latency, + libcode=best_kernel.get_kernel_source(), + func=best_kernel.prim_func, + kernel=best_kernel, + ) + + if self.compile_args.execution_backend in ("torch"): + logger.warning("DLPack backend does not support cache saving to disk.") + else: + with self._lock: + if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): + self._save_result_to_disk(key, autotuner_result) + + self._memory_cache[key] = autotuner_result + + return autotuner_result + + def __call__(self) -> Any: + """Make the AutoTuner callable, running the auto-tuning process. + + Returns: + AutotuneResult: Results of the auto-tuning process. + """ + return self.run() + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +@dataclass +class AutoTuneImpl(Generic[_P, _T]): + jit_impl: JITImpl + + warmup: int = 25 + rep: int = 100 + timeout: int = 100 + configs: dict | Callable = None + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto + ref_prog: Callable = None + supply_prog: Callable = None + rtol: float = 1e-2 + atol: float = 1e-2 + max_mismatched_ratio: float = 0.01 + skip_check: bool = False + manual_check_prog: Callable = None + cache_input_tensors: bool = False + + def __post_init__(self): + self._tuner_cache = {} + + def get_tunner(self): + autotuner = ( + AutoTuner(self.jit_impl.func, configs=self.configs) + .set_profile_args( + supply_type=self.supply_type, + ref_prog=self.ref_prog, + supply_prog=self.supply_prog, + rtol=self.rtol, + atol=self.atol, + max_mismatched_ratio=self.max_mismatched_ratio, + skip_check=self.skip_check, + manual_check_prog=self.manual_check_prog, + cache_input_tensors=self.cache_input_tensors, + ) + .set_compile_args( + out_idx=self.jit_impl.out_idx, + execution_backend=self.jit_impl.execution_backend, + target=self.jit_impl.target, + target_host=self.jit_impl.target_host, + verbose=self.jit_impl.verbose, + pass_configs=self.jit_impl.pass_configs, + ) + ) + autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout) + return autotuner + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + key_args_tuple = args + key_kwargs_tuple = tuple(sorted(kwargs.items())) + key = (key_args_tuple, key_kwargs_tuple) + if key not in self._tuner_cache: + + def jit_compile(**config_arg): + return self.jit_impl(*args, **kwargs, __tune_params=config_arg) + + autotuner = self.get_tunner() + autotuner.jit_compile = 
jit_compile + autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) + artifact = autotuner.run() + self._tuner_cache[key] = artifact.kernel + return self._tuner_cache[key] + + +def autotune( # This is the new public interface + func: Callable[_P, _T] | PrimFunc | None = None, + *, # Indicates subsequent arguments are keyword-only + configs: dict | Callable, + # profile arguments + warmup: int = 25, + rep: int = 100, + timeout: int = 100, + # compile arguments + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, + ref_prog: Callable = None, + supply_prog: Callable = None, + rtol: float = 1e-2, + atol: float = 1e-2, + max_mismatched_ratio: float = 0.01, + skip_check: bool = False, + manual_check_prog: Callable = None, + cache_input_tensors: bool = False, +): + """ + Just-In-Time (JIT) compiler decorator for TileLang functions. + + This decorator can be used without arguments (e.g., `@tilelang.jit`): + Applies JIT compilation with default settings. + + Tips: + - If you want to skip the auto-tuning process, you can set override the tunable parameters in the function signature. + ```python + if enable_autotune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + else: + kernel = flashattn( + batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + ``` + + Parameters + ---------- + func_or_out_idx : Any, optional + If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. + If using `@tilelang.jit` directly on a function, this argument is implicitly + the function to be decorated (and `out_idx` will be `None`). + configs : Dict or Callable + Configuration space to explore during auto-tuning. + warmup : int, optional + Number of warmup iterations before timing. + rep : int, optional + Number of repetitions for timing measurements. + timeout : int, optional + target : Union[str, Target], optional + Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". + target_host : Union[str, Target], optional + Target host for cross-compilation. Defaults to None. + execution_backend : Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Backend for kernel execution and argument passing. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). + verbose : bool, optional + Enables verbose logging during compilation. Defaults to False. + pass_configs : Optional[Dict[str, Any]], optional + Configurations for TVM's pass context. Defaults to None. + debug_root_path : Optional[str], optional + Directory to save compiled kernel source for debugging. Defaults to None. + + Returns + ------- + Callable + Either a JIT-compiled wrapper around the input function, or a configured decorator + instance that can then be applied to a function. + """ + if callable(func): + # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults) + # This is a placeholder for a real auto tuner implementation + raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.") + elif isinstance(func, PrimFunc): + raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") + else: + + def decorator(impl): + assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." 
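+ # Wrap the @tilelang.jit instance: tuning runs lazily on the first call with a given
+ # (args, kwargs) signature, and the resulting kernel is reused from `_tuner_cache`
+ # on subsequent calls with the same signature.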
+ return AutoTuneImpl( + jit_impl=impl, + configs=configs, + warmup=warmup, + rep=rep, + timeout=timeout, + supply_type=supply_type, + ref_prog=ref_prog, + supply_prog=supply_prog, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + skip_check=skip_check, + manual_check_prog=manual_check_prog, + cache_input_tensors=cache_input_tensors, + ) + + return decorator diff --git a/tilelang/original/tilelang/cache/__init__.py b/tilelang/original/tilelang/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18ac847bf4adbad368b8f853cb152287415cdb9f --- /dev/null +++ b/tilelang/original/tilelang/cache/__init__.py @@ -0,0 +1,59 @@ +"""The cache utils with class and database persistence - Init file""" + +from __future__ import annotations + +from typing import Literal +from tvm.target import Target +from tvm.tir import PrimFunc +from tilelang.jit import JITKernel +from tilelang import env +from .kernel_cache import KernelCache + +# Create singleton instance of KernelCache +_kernel_cache_instance = KernelCache() + + +def cached( + func: PrimFunc = None, + out_idx: list[int] = None, + *args, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto", + verbose: bool | None = False, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel: + """ + Caches and reuses compiled kernels (using KernelCache class). + """ + return _kernel_cache_instance.cached( + func, + out_idx, + *args, + target=target, + target_host=target_host, + execution_backend=execution_backend, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def clear_cache(): + """ + Disabled helper that previously removed the entire kernel cache. + + Raises: + RuntimeError: Always raised to warn users to clear the cache manually. + """ + cache_dir = env.TILELANG_CACHE_DIR + raise RuntimeError( + "tilelang.clear_cache() is disabled because deleting the cache directory " + "is dangerous. If you accept the risk, remove it manually with " + f"`rm -rf '{cache_dir}'`." 
+ ) + + +if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): + clear_cache() diff --git a/tilelang/original/tilelang/cache/kernel_cache.py b/tilelang/original/tilelang/cache/kernel_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6a5591b070e5ad96190c7e08fbdba825ae4202 --- /dev/null +++ b/tilelang/original/tilelang/cache/kernel_cache.py @@ -0,0 +1,505 @@ +"""The cache utils with class and database persistence - KernelCache Class""" + +from __future__ import annotations + +import json +import logging +import os +import shutil +import threading +import uuid +from hashlib import sha256 +from typing import Callable, Literal + +import cloudpickle +from tvm.target import Target +from tvm.tir import PrimFunc +from tvm.runtime import Executable +from tilelang.engine.param import KernelParam +from tilelang.utils.language import get_prim_func_name +from tilelang import env +from tilelang.jit import JITKernel +from tilelang import __version__ + +DEVICE_KERNEL_PATH = "device_kernel.cu" +HOST_KERNEL_PATH = "host_kernel.cu" +EXECUTABLE_PATH = "executable.so" +KERNEL_LIB_PATH = "kernel_lib.so" +KERNEL_CUBIN_PATH = "kernel.cubin" +KERNEL_PY_PATH = "kernel.py" +PARAMS_PATH = "params.pkl" + +# CuTeDSL C++ launcher specific +LAUNCHER_LIB_PATH = "launcher_lib.so" +LAUNCHER_CPP_PATH = "launcher.cpp" +CUTEDSL_CUBIN_PATH = "kernel.cubin" + + +class KernelCache: + """ + Caches compiled kernels using a class and database persistence to avoid redundant compilation. + Cache files: + kernel.cu: The compiled kernel source code + wrapped_kernel.cu: The compiled wrapped kernel source code + kernel_lib.so: The compiled kernel library + params.pkl: The compiled kernel parameters + """ + + _instance = None # For implementing singleton pattern + _lock = threading.Lock() # For thread safety + _memory_cache = {} # In-memory cache dictionary + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi" + + def __new__(cls): + """ + Implements singleton pattern for KernelCache class. + + Returns: + KernelCache: The singleton instance of KernelCache. + """ + if cls._instance is None: + with cls._lock: + if cls._instance is None: # Double-checked locking + instance = super().__new__(cls) + KernelCache._create_dirs() + instance.logger = logging.getLogger(__name__) + instance.logger.setLevel(logging.DEBUG) + instance._memory_cache = {} # Initialize memory cache + cls._instance = instance + return cls._instance + + @staticmethod + def _create_dirs(): + os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True) + os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) + + def _generate_key( + self, + func: Callable, + out_idx: list[int], + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + args=None, + target: str | Target = "auto", + target_host: str | Target = None, + pass_configs: dict = None, + compile_flags: list[str] | str | None = None, + ) -> str: + """ + Generates a unique hash key for caching compiled kernels. + + Args: + func (Callable): The function to be compiled. + out_idx (List[int]): Indices specifying which outputs to return. + execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi". + args: Arguments passed to the function. + target (Union[str, Target]): Compilation target platform. Defaults to "auto". + target_host (Union[str, Target], optional): Host target platform. + + Returns: + str: SHA256 hash key for the kernel configuration. 
+ """ + self.execution_backend = execution_backend + func_binary = cloudpickle.dumps(func.script(show_meta=True)) + key_data = { + "version": __version__, + "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key + "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), + "args_repr": tuple(repr(arg) for arg in args), # Use repr to serialize arguments, may need more robust serialization + "target": str(target), + "target_host": str(target_host) if target_host else None, + "execution_backend": execution_backend, + "pass_configs": pass_configs, + "compile_flags": compile_flags, + } + # Sort keys to ensure consistency + key_string = json.dumps(key_data, sort_keys=True) + # Use SHA256 to generate hash key + return sha256(key_string.encode()).hexdigest() + + def cached( + self, + func: PrimFunc = None, + out_idx: list[int] = None, + *args, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", + verbose: bool = False, + pass_configs: dict = None, + compile_flags: list[str] | str | None = None, + ) -> JITKernel: + """ + Caches and reuses compiled kernels to avoid redundant compilation. + + Args: + func: Function to be compiled or a prepared PrimFunc + out_idx: Indices specifying which outputs to return + target: Compilation target platform + target_host: Host target platform + *args: Arguments passed to func + + Returns: + JITKernel: The compiled kernel, either freshly compiled or from cache + """ + # Normalize target and resolve execution backend before proceeding + from tilelang.utils.target import determine_target as _determine_target + from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + + norm_target = Target(_determine_target(target)) if isinstance(target, str) else target + requested_backend = execution_backend + execution_backend = resolve_execution_backend(requested_backend, norm_target) + if verbose: + allowed_now = allowed_backends_for_target(norm_target, include_unavailable=False) + # Avoid duplicate logs when caller already resolved explicitly + if requested_backend in (None, "auto") or requested_backend != execution_backend: + self.logger.info( + "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", + execution_backend, + requested_backend, + norm_target.kind.name, + ", ".join(sorted(allowed_now)), + ) + + if not env.is_cache_enabled(): + return JITKernel( + func, + out_idx=out_idx, + execution_backend=execution_backend, + target=norm_target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + key = self._generate_key( + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + args=args, + target=norm_target, + target_host=target_host, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + with self._lock: + # First check in-memory cache + if key in self._memory_cache: + # Include kernel name for easier debugging when hitting memory cache + kernel_name = get_prim_func_name(func, "") + self.logger.warning( + "Found kernel '%s' in memory cache. 
For better performance, consider using `@tilelang.jit` instead of direct kernel caching.", + kernel_name, + ) + return self._memory_cache[key] + + if verbose: + self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") + + # Then check disk cache + kernel = self._load_kernel_from_disk( + key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose + ) + if kernel is not None: + if verbose: + self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") + # Populate memory cache with disk result + self._memory_cache[key] = kernel + return kernel + + if verbose: + self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '')}") + # Compile kernel if cache miss; leave critical section + kernel = JITKernel( + func, + out_idx=out_idx, + execution_backend=execution_backend, + target=norm_target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + with self._lock: + if env.is_cache_enabled(): + cache_path = self._get_cache_path(key) + self._save_kernel_to_disk(key, kernel, func, verbose) + # Set cache path on adapter so it can save cubin after first execution + if hasattr(kernel, "adapter") and execution_backend == "cutedsl": + kernel.adapter._cache_path = cache_path + + # Store in memory cache after compilation + self._memory_cache[key] = kernel + return kernel + + def clear_cache(self): + """ + Clears the entire kernel cache, including both in-memory and disk cache. + """ + with self._lock: + self._memory_cache.clear() # Clear in-memory cache + self._clear_disk_cache() # Clear disk cache + + def _get_cache_path(self, key: str) -> str: + """ + Gets the filesystem path for a cached kernel. + + Args: + key (str): The hash key identifying the kernel. + + Returns: + str: Absolute path to the cache directory for this kernel. + """ + return os.path.join(env.TILELANG_CACHE_DIR, key) + + @staticmethod + def _load_binary(path: str): + with open(path, "rb") as file: + binary = file.read() + return binary + + @staticmethod + def _safe_write_file(path: str, mode: str, operation: Callable): + # Random a temporary file within the same FS as the cache directory + temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}") + with open(temp_path, mode) as temp_file: + operation(temp_file) + + # Use atomic POSIX replace, so other processes cannot see a partial write + os.replace(temp_path, path) + + @staticmethod + def _safe_write_executable(executable: Executable, path: str): + temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}.so") + executable.export_library(temp_path) + os.replace(temp_path, path) + + def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False): + """ + Persists a compiled kernel to disk cache. + + Args: + key (str): The hash key identifying the kernel. + kernel (JITKernel): The compiled kernel to be saved. + func (Callable, optional): The original function. + verbose (bool): Enable verbose log messages. 
+ + Note: + Saves the following files: + - kernel.cu: The compiled kernel source code + - wrapped_kernel.cu: The wrapped kernel source code + - kernel_lib.so: The compiled kernel library + - params.pkl: The serialized kernel parameters + """ + cache_path = self._get_cache_path(key) + os.makedirs(cache_path, exist_ok=True) # Ensure directory exists + + # Save kernel source code + try: + if self.execution_backend != "cutedsl": + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + if verbose: + self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") + if kernel.kernel_source is not None: + KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) + except Exception: + self.logger.exception("Error saving kernel source code to disk") + + # Save wrapped kernel source code + try: + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH if self.execution_backend != "cutedsl" else KERNEL_PY_PATH) + if verbose: + self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") + if self.execution_backend == "tvm_ffi": + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source())) + else: + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source())) + except Exception: + self.logger.exception("Error saving host kernel source code to disk") + + # Save the kernel library + try: + # Save CUBIN or SO file + if self.execution_backend == "cutedsl": + # For CuTeDSL, kernel_lib_path is the Python module + kernel_lib_path = os.path.join(cache_path, KERNEL_PY_PATH) + + # Save C++ launcher library if it exists + lib_gen = getattr(kernel.adapter, "lib_generator", None) + if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath: + launcher_lib_path = os.path.join(cache_path, LAUNCHER_LIB_PATH) + src_launcher_path = lib_gen.launcher_libpath + if verbose: + self.logger.debug(f"Saving C++ launcher library to cache: {src_launcher_path}") + KernelCache._safe_write_file( + launcher_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_launcher_path)) + ) + + # Optionally save launcher C++ source for debugging + if hasattr(kernel.adapter, "launcher_cpp_code") and kernel.adapter.launcher_cpp_code: + launcher_cpp_path = os.path.join(cache_path, LAUNCHER_CPP_PATH) + if verbose: + self.logger.debug(f"Saving C++ launcher source to: {launcher_cpp_path}") + KernelCache._safe_write_file(launcher_cpp_path, "w", lambda file: file.write(kernel.adapter.launcher_cpp_code)) + + else: + if self.execution_backend == "nvrtc": + kernel_lib_path = KERNEL_CUBIN_PATH + elif self.execution_backend == "tvm_ffi": + kernel_lib_path = EXECUTABLE_PATH + else: + kernel_lib_path = KERNEL_LIB_PATH + kernel_lib_path = os.path.join(cache_path, kernel_lib_path) + + # Save an extra Python file for NVRTC + if self.execution_backend == "nvrtc": + src_lib_path = kernel.adapter.libpath + kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) + src_lib_path = src_lib_path.replace(".cubin", ".py") + if verbose: + self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") + KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) + + if self.execution_backend == "tvm_ffi": + executable = kernel.adapter.executable + if verbose: + self.logger.debug(f"Saving kernel executable to file: {executable}") + 
KernelCache._safe_write_executable(executable, kernel_lib_path) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) + + except Exception: + self.logger.exception("Error saving kernel library to disk") + + # Save kernel parameters + try: + params_path = os.path.join(cache_path, PARAMS_PATH) + if verbose: + self.logger.debug(f"Saving kernel parameters to disk: {params_path}") + KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) + except Exception: + self.logger.exception("Error saving kernel parameters to disk") + + def _load_kernel_from_disk( + self, + key: str, + target: str | Target = "auto", + target_host: str | Target | None = None, + out_idx: list[int] | None = None, + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, + func: Callable | None = None, + verbose: bool = False, + ) -> JITKernel | None: + """ + Loads a previously compiled kernel from disk cache. + + Args: + key (str): The hash key identifying the kernel. + target (Union[str, Target]): Compilation target platform. Defaults to "auto". + target_host (Union[str, Target], optional): Host target platform. + out_idx (List[int], optional): Indices specifying which outputs to return. + execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi". + pass_configs (dict, optional): Configuration for compiler passes. + func (Callable, optional): The original function. + verbose (bool): Enable verbose log messages. + + Returns: + JITKernel: The loaded kernel if found, None otherwise. 
+ """ + cache_path = self._get_cache_path(key) + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + if self.execution_backend == "nvrtc": + kernel_lib_path = KERNEL_CUBIN_PATH + elif self.execution_backend == "tvm_ffi": + kernel_lib_path = EXECUTABLE_PATH + elif self.execution_backend == "cutedsl": + kernel_lib_path = KERNEL_PY_PATH + else: + kernel_lib_path = KERNEL_LIB_PATH + kernel_lib_path = os.path.join(cache_path, kernel_lib_path) + params_path = os.path.join(cache_path, PARAMS_PATH) + + # Check required files exist + required_files = [kernel_lib_path, params_path] + + # For CuTeDSL, also check launcher library + if self.execution_backend == "cutedsl": + required_files.append(os.path.join(cache_path, LAUNCHER_LIB_PATH)) + + if not all([os.path.exists(file) for file in required_files]): + return None + + device_kernel_source: str | None = None + host_kernel_source: str | None = None + kernel_params: list[KernelParam] | None = None + + # Load the kernel source file (optional) + if self.execution_backend != "cutedsl": + try: + if verbose: + self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}") + with open(device_kernel_path) as f: + device_kernel_source = f.read() + except Exception: + self.logger.exception("Error loading kernel source code from disk") + try: + if verbose: + self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") + with open(host_kernel_path) as f: + host_kernel_source = f.read() + except Exception: + self.logger.exception("Error loading host kernel source code from disk") + else: + # For CuTeDSL, set empty strings since sources aren't loaded from cache + device_kernel_source = "" + host_kernel_source = "" + + # Load kernel parameters + try: + if verbose: + self.logger.debug(f"Loading kernel parameters from file: {params_path}") + with open(params_path, "rb") as f: + kernel_params = cloudpickle.load(f) + except Exception: + self.logger.exception("Error loading kernel parameters from disk") + + if ((host_kernel_source and device_kernel_source) or self.execution_backend == "cutedsl") and kernel_params: + return JITKernel.from_database( + func=func, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + params=kernel_params, + target=target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + else: + # TODO(lei): report what the reason is. + return None + + def _clear_disk_cache(self): + """ + Removes all cached kernels from disk. + + Note: + This operation will delete the entire cache directory and recreate it empty. + Use with caution as this operation cannot be undone. 
+ """ + try: + # Delete the entire cache directory + shutil.rmtree(env.TILELANG_CACHE_DIR) + + # Re-create the cache directory + KernelCache._create_dirs() + except Exception: + self.logger.exception("Error clearing disk cache") diff --git a/tilelang/original/tilelang/carver/README.md b/tilelang/original/tilelang/carver/README.md new file mode 100644 index 0000000000000000000000000000000000000000..65030d0a22ebe4e8415017bed45978b1185c35c6 --- /dev/null +++ b/tilelang/original/tilelang/carver/README.md @@ -0,0 +1,210 @@ +# Carver: A Tile-Structure Based Hint Recommend Framework for Machine Learning Compilers + +**Carver** is a lightweight framework for generating and ranking tile configurations (also known as **tiling strategies**, **blocking schemes**, or **scheduling hints**) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels. + +Carver combines hardware architecture information, user-defined tile structures, and built-in heuristics to recommend tiling strategies (or "hints"). The recommended hints are easily adaptable to multiple backends, including [TVM](https://tvm.apache.org/), [triton](https://github.com/openai/triton), [tilelang](https://github.com/tile-ai/tilelang) (or other domain-specific compilers). + +--- + +### Key Features +- **Unified Tiling Framework**: Generate tile candidates for multiple backends under a unified API. +- **Architecture-Specific Modeling**: Take into account architecture constraints (e.g., CUDA `smem_cap`, warp size, CPU cache structure, etc.) when generating hints. +- **Flexible Templates**: High-level templates (like `MatmulTemplate`, `GeneralReductionTemplate`, `ElementwiseTemplate`) let you concisely specify kernel structures. +- **Extendable**: Easily add support for new backends and new operation templates. + +--- + +## Usage Examples + +### Basic Usage: General Reduction Template + +Once installed tilelang, you can import Carver and start creating templates: + +```python +from tilelang import carver +from tilelang.carver.arch import CUDA + +# Instantiate a CUDA device object for an RTX 4090 +arch = CUDA("nvidia/geforce-rtx-4090") + +# Create a general reduction template for a loop nest: +# for i in Spatial(1024): +# for j in Spatial(1024): +# for k in Reduce(1024): +# ... +carve_template = carver.GeneralReductionTemplate( + structure="SSR", + shape=[1024, 1024, 1024], + dtype="float16", +).with_arch(arch) + +# Generate top 20 tile candidates (aka scheduling hints) +hints = carve_template.recommend_hints(topk=20) +for hint in hints: + print(hint) +``` + +**Example Output** (truncated): +```python +{ + 'block': [1, 128], + 'thread': [1, 128], + 'rstep': [64], + ... +}, +{ + 'block': [2, 64], + 'thread': [2, 64], + 'rstep': [64], + ... +}, +... +{ + 'block': [1, 16], + 'thread': [1, 16], + 'rstep': [512], + 'reduce_thread': [8], + ... +} +``` + +A tile structure composed of S and R can simulate various cases. For example, structure `SS` represents a 2D element-wise operation, while `SSR` can represent a general matrix multiplication. + +We can specialize more advanced templates to provide finer-grained information, such as `MatmulTemplate`. + + +### Matmul Template + +Carver also provides a specialized `MatmulTemplate` for matrix multiplication (e.g., `C = A * B`), automatically inferring common tiling strategies (thread blocks, warps, use of tensor cores, etc.). 
+ +```python +from tilelang import carver +from tilelang.carver.arch import CUDA + +arch = CUDA("nvidia/geforce-rtx-4090") +carve_template = carver.MatmulTemplate( + M=1024, + N=1024, + K=1024, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", +).with_arch(arch) + +# Retrieve the (symbolic) function describing the matmul +func = carve_template.equivalent_function() +print("Equivalent Function:\n", func) + +# Generate hints +hints = carve_template.recommend_hints(topk=20) +for hint in hints: + print(hint) +``` + +**Example Output**: +```python +{ + 'block': [32, 64], + 'warp': [16, 32], + 'rstep': [128], + 'use_tc': True, + ... +}, +{ + 'block': [64, 32], + 'warp': [32, 16], + 'rstep': [128], + 'use_tc': True, + ... +}, +... +{ + 'block': [256, 32], + 'warp': [128, 16], + 'rstep': [32], + 'use_tc': True, + ... +} +``` + +--- + +## Supported Architectures + +Carver currently provides out-of-the-box support for: +- **CUDA**: e.g., `arch = CUDA("nvidia/geforce-rtx-4090")` +- **CDNA** (AMD GPU-like backends) +- **CPU** + +Adding a new architecture is as simple as implementing a new subclass of `TileDevice` or providing a custom target that describes: +- Shared/local memory capacity +- Warp (or vector) size +- Cache sizes +- Tensor instructions available + +Below is an **illustrative snippet** of the CUDA backend: +```python +class CUDA(TileDevice): + def __init__(self, target: Union[tvm.target.Target, str]): + ... + self.platform = "CUDA" + # Device constraints + self.smem_cap = device.max_shared_memory_per_block + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + ... + self.transaction_size = [32, 128] # bytes + self.bandwidth = [750, 12080] # MB/s, approximate + self.available_tensor_instructions = None + + def get_avaliable_tensorintrin_shapes(self): + self.available_tensor_instructions = ( + TensorInstruction("mma", [16, 16]), + TensorInstruction("wmma", [16, 16]), + ) + return [t.shape for t in self.available_tensor_instructions] + + def __repr__(self): + return f"CUDA({self.target})" +``` + +## Adapting Hints to Other Compilers + +One of Carver’s main benefits is its adaptability. Here are a examples for triton lang: + +Given a Carver hint like: +```python +{ + 'block': [32, 64], + 'warp': [16, 32], + 'rstep': [128], + 'use_tc': True, + 'vectorize': {'A_reindex': 8, 'B_reindex': 8} +} +``` +You might interpret this in **Triton** as: +- `block_m = 32, block_n = 64, block_k = 128` +- Potential warp usage = `warp_m = 16, warp_n = 32` +- `vectorize`: load data with a vector width of 8 +- If `use_tc` is true, consider using Tensor Cores (TensorOps in Triton) if supported. + +This helps quickly test multiple configurations without manually guessing. + + + +## Supported Templates + +Carver abstracts common loop patterns through templates: +- **`GeneralReductionTemplate`**: For general `Spatial-Spatial-Reduce` (SSR) structures or similar. +- **`FlashAttentionTemplate`**: For attention-like operations with flash memory. +- **`MatmulTemplate`**: For standard matrix multiplication `C = A * B`. +- **`GEMVTemplate`**: For `y = Ax` or `y = xA` style operations. +- **`ElementwiseTemplate`**: For elementwise transformations or pointwise ops. + +You can also create your own specialized templates if you have unique loop structures or constraints. For instance, you might define specialized templates for convolution, flash attention, etc. 
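+
+## Putting It Together
+
+The minimal sketch below ties the pieces above together: it generates ranked hints with `MatmulTemplate` (same API as the earlier example) and maps the top hint onto Triton-style tile sizes. It assumes `recommend_hints` returns a best-first list and that each hint is a plain dictionary with the `block`/`rstep`/`use_tc` fields shown in the example output above.
+
+```python
+from tilelang import carver
+from tilelang.carver.arch import CUDA
+
+# Generate ranked hints for a 1024x1024x1024 matmul (template API as shown above).
+arch = CUDA("nvidia/geforce-rtx-4090")
+template = carver.MatmulTemplate(
+    M=1024, N=1024, K=1024,
+    in_dtype="float16", accum_dtype="float16", out_dtype="float16",
+).with_arch(arch)
+hints = template.recommend_hints(topk=20)
+
+# Map the top-ranked hint onto Triton-style launch parameters
+# (field names taken from the example hint output shown earlier).
+best = hints[0]
+block_m, block_n = best["block"]
+(block_k,) = best["rstep"]
+print(f"BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}, "
+      f"use_tensor_cores={best.get('use_tc', False)}")
+```
+
+From here, the extracted values can be passed as constexpr parameters to a Triton kernel or used to parameterize a tilelang schedule.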
+ + +## TODO Items + +- [ ] **Adapt to tile language**: Provide ready-made scheduling calls or wrappers for [tilelang](https://github.com/LeiYanggh/tilelang) to streamline end-to-end integration. + diff --git a/tilelang/original/tilelang/carver/__init__.py b/tilelang/original/tilelang/carver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dfc5b4750d23e679535b11c34b1bfe7ceb6576 --- /dev/null +++ b/tilelang/original/tilelang/carver/__init__.py @@ -0,0 +1,15 @@ +"""Base infra""" + +from .analysis import ( + BlockInfo, # noqa: F401 + IterInfo, # noqa: F401 + collect_block_iter_vars_used_in_access_region, # noqa: F401 + collect_vars_used_in_prim_expr, # noqa: F401 + detect_dominant_read, # noqa: F401 + is_broadcast_epilogue, # noqa: F401 + normalize_prim_func, # noqa: F401 +) # noqa: F401 +from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401 +from .roller import * +from .arch import CUDA, CDNA # noqa: F401 +from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate # noqa: F401 diff --git a/tilelang/original/tilelang/carver/analysis.py b/tilelang/original/tilelang/carver/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca9168185f2d3149e273cbe2e3517d385eda204 --- /dev/null +++ b/tilelang/original/tilelang/carver/analysis.py @@ -0,0 +1,296 @@ +"""Analysis on TIR blocks, loops and functions.""" + +from __future__ import annotations +from typing_extensions import Literal + +from tvm import ir, tir, DataType +from tvm.ffi import get_global_func +from tvm.target.target import Target +from tvm.tir import Schedule, IterVar +from tvm.tir.schedule import BlockRV + + +class IterInfo: + """Information about a loop/iter var.""" + + kind: Literal["S", "R", "O"] + var: tir.Var + _dom: tir.PrimExpr + loop_rv: tir.schedule.LoopRV + + def __init__( + self, + kind: Literal["S", "R", "O"], + var: tir.Var, + dom: tir.PrimExpr, + loop_rv: tir.schedule.LoopRV, + ): + """Construct an IterInfo object.""" + self.kind = kind + self.var = var + self._dom = dom + self.loop_rv = loop_rv + + @property + def dom(self) -> int | tir.PrimExpr: + """The iteration domain of the loop.""" + return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom + + def __str__(self) -> str: + return f'Iter("{self.kind}", {self.dom})' + + def __repr__(self) -> str: + return str(self) + + +class BlockInfo: + """Information about a TIR block.""" + + name: str + iters: list[IterInfo] + block_rv: tir.schedule.BlockRV + _reduction_block: bool + + def __init__( + self, + name: str, + iters: list[IterInfo], + block_rv: tir.schedule.BlockRV, + reduction_block: bool = False, + ): + """Construct a BlockInfo object.""" + self.name = name + self.block_rv = block_rv + self.iters = iters + self._reduction_block = reduction_block + + def dom(self) -> list[int | tir.PrimExpr]: + """The iteration domain of the block.""" + return [i.dom for i in self.iters] + + def dom_kind(self) -> str: + """The iteration domain kind of the block, for example, SSSS, SSSR.""" + return "".join(i.kind for i in self.iters) + + def is_injective(self) -> bool: + """Whether the block is injective, i.e. all its iteration domains are injective.""" + return all(k == "S" for k in self.dom_kind()) + + def is_elementwise(self, sch: tir.Schedule) -> bool: + """Whether the block is elementwise, i.e. 
trivial mapping between read/write region""" + + def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: + return dom.min.same_as(var) and dom.extent == 1 + + if not self.is_injective(): + return False + block = sch.get(self.block_rv) + if len(block.reads) != 1 or len(block.writes) != 1: + return False + r_region = block.reads[0].region + w_region = block.writes[0].region + if len(r_region) != len(w_region): + return False + for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region): + if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range(var, w_dom): + return False + return True + + def is_reduction(self) -> bool: + """Whether the block is a reduction workload.""" + # TODO(@junrushao): distinguish GEMV and reduction + return self._reduction_block + + def is_gemv(self) -> bool: + """Whether the block is a GEMV workload.""" + raise NotImplementedError + + def is_gemm(self) -> bool: + """Whether the block is a GEMM workload.""" + raise NotImplementedError + + def __str__(self) -> str: + return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' + + def __repr__(self) -> str: + return str(self) + + +_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") + + +def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None: + """Normalize the primfunc to normal form""" + try: + result = _normalize_prim_func(sch) + if result is None: + return None + except Exception: # pylint: disable=broad-except + return None + + def _iter_kind(i: tir.IterVar) -> str: + return { + tir.IterVar.DataPar: "S", + tir.IterVar.CommReduce: "R", + }.get(i.iter_type, "O") + + blocks: list[BlockInfo] = [] + for block, loops, iters, is_reduction in zip(*result): + blocks.append( + BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter), # type: ignore + var=iter.var, + dom=iter.dom, + loop_rv=loop, + ) + for loop, iter in zip(loops, iters) + ], + block_rv=block, + reduction_block=is_reduction, + ) + ) + return blocks + + +def find_var_from_func(func, var: str): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var) and i.name == var: + return i + return None + + +def check_func_with_dynamic(func): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var): + return True + return False + + +def _assert_gpu_target(target: Target): + if "gpu" not in target.keys: + raise ValueError(f"Expect a GPU target, but got {target}") + + +def get_max_threads_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_threads_per_block = None + for name in ["max_threads_per_block", "max_num_threads"]: + if max_threads_per_block is None: + max_threads_per_block = target.attrs.get(name, None) + if max_threads_per_block is None: + max_threads_per_block = 64 + return int(max_threads_per_block) + + +def get_max_shared_memory_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) + if max_shared_memory_per_block is None: + raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") + return int(max_shared_memory_per_block) + + +def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: + try: + block = sch.mod[func_name].body.block + except Exception: + raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None + return sch.get_block(block.name_hint) + + 
+def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]: + """Collect the block iter variables used in the access region of a buffer region.""" + tir_vars = set() + for expr in region: + if expr.extent != 1: + continue + tir_vars |= collect_vars_used_in_prim_expr(expr.min) + tir_vars &= set(iter_var.var for iter_var in block.iter_vars) + return tir_vars + + +def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]: + """Collect the variables used in the PrimExpr.""" + tir_vars = set() + + def _collect_tir_var(expr): + if isinstance(expr, tir.Var): + tir_vars.add(expr) + + tir.stmt_functor.post_order_visit(expr, _collect_tir_var) + return tir_vars + + +def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: + """Detect the dominant read indices in the block.""" + dominant_read = None + num_read_iters = -1 + for buffer_region in block.reads: + tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) + if num_read_iters < len(tir_vars): + num_read_iters = len(tir_vars) + dominant_read = buffer_region + assert dominant_read is not None + (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) + return result + + +def is_broadcast_epilogue( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + epilogue: tir.schedule.BlockRV, +) -> bool: + """Check if the epilogue block is a broadcast pattern""" + write_buffers = {r.buffer for r in sch.get(block).writes} + epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1} + for buffer_region in sch.get(epilogue).reads: + if buffer_region.buffer not in write_buffers: + continue + tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region) + if len(tir_vars) < len(epilogue_iters): + return True + return False + + +def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) == 0: + return None + return reduction_blocks + + +def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: + # gpu memory prefer 128 bits coalesced access (e.g. 
four banks) + # 128 bits + buffers: list[tir.Buffer] = [] + for read in block_stmt.reads: + buffers.append(read.buffer) + for write in block_stmt.writes: + buffers.append(write.buffer) + # pick the dtype with the largest bits + max_dtype_bits: int = 0 + for buffer in buffers: + max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits) + return target_bits // max_dtype_bits diff --git a/tilelang/original/tilelang/carver/arch/__init__.py b/tilelang/original/tilelang/carver/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cb9e72f7243c46b9d13636e91b37ab887a0625 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/__init__.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from .arch_base import TileDevice +from .cuda import * +from .cpu import * +from .cdna import * +from .metal import * +from tvm.target import Target +import torch + + +def get_arch(target: str | Target = "cuda") -> TileDevice: + if isinstance(target, str): + target = Target(target) + + if target.kind.name == "cuda": + return CUDA(target) + elif target.kind.name == "llvm": + return CPU(target) + elif target.kind.name == "hip": + return CDNA(target) + elif target.kind.name == "metal": + return METAL(target) + else: + raise ValueError(f"Unsupported target: {target.kind.name}") + + +def auto_infer_current_arch() -> TileDevice: + # TODO(lei): This is a temporary solution to infer the current architecture + # Can be replaced by a more sophisticated method in the future + if torch.version.hip is not None: + return get_arch("hip") + if torch.cuda.is_available(): + return get_arch("cuda") + elif torch.mps.is_available(): + return get_arch("metal") + else: + return get_arch("llvm") + + +__all__ = [ + "is_cpu_arch", + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", + "is_cdna_arch", + "is_metal_arch", + "CUDA", + "CDNA", + "METAL", + "CPU", +] diff --git a/tilelang/original/tilelang/carver/arch/arch_base.py b/tilelang/original/tilelang/carver/arch/arch_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e9dfa683c0bb357008b2748ae280b2c52ee04a --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/arch_base.py @@ -0,0 +1,30 @@ +class TileDevice: + """ + Represents the architecture of a computing device, capturing various hardware specifications. 
+ """ + + def __init__(self) -> None: + self.reg_cap: int = 0 # Register capacity: The amount of register memory available + self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available + self.compute_max_core: int = 0 # The maximum number of computing cores + self.warp_size: int = 0 # The size of a warp, a group of threads that execute instructions in lockstep + self.sm_partition: int = 0 # The number of streaming multiprocessor partitions + self.transaction_size: list[int] = [ + 0, + 0, + ] # The size of memory transactions, typically in bytes + self.max_smem_usage: int = 0 # The maximum shared memory usage allowed + self.bandwidth: list[int] = [ + 0, + 0, + ] # Bandwidth specifications, possibly including peak and sustained rates + self.platform: str = "unknown" # The platform or manufacturer of the device + self.compute_capability: str = "unknown" # The compute capability, indicating the feature set and performance level + self.l2_cache_size_bytes: int = 0 + # the number of transaction size in bytes + self.transaction_size: list[int] = [0, 0] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + self.bandwidth: list[int] = [0, 0] + + def get_avaliable_tensorintrin_shapes(self): + raise NotImplementedError() diff --git a/tilelang/original/tilelang/carver/arch/cdna.py b/tilelang/original/tilelang/carver/arch/cdna.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2d4c4ed6722e2577a2755d2af73ff76eabbf41 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/cdna.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import tvm +from tvm.target import Target +from .arch_base import TileDevice + + +def is_cdna_arch(arch: TileDevice) -> bool: + return isinstance(arch, CDNA) + + +class CDNA(TileDevice): + def __init__(self, target: Target | str): + if isinstance(target, str): + target = tvm.target.Target(target) + self.target = target + device = tvm.runtime.rocm(0) + if not device.exist: + raise RuntimeError("Cannot find HIP device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CDNA" + self.smem_cap = device.max_shared_memory_per_block + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + self.compute_capability = device.compute_version.replace(".", "") + self.reg_cap: int = 32768 + self.max_smem_usage: int = 2 * self.smem_cap + self.sm_partition: int = 4 + self.l2_cache_size_bytes: int = target.l2_cache_size_bytes + self.transaction_size: list[int] = [32, 128] # in bytes + + self.bandwidth: list[int] = [1300, 14000] + + +__all__ = [ + "is_cdna_arch", + "CDNA", +] diff --git a/tilelang/original/tilelang/carver/arch/cpu.py b/tilelang/original/tilelang/carver/arch/cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..fc18c6c8b35e703c1efe98b4ed1378868823d741 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/cpu.py @@ -0,0 +1,25 @@ +import tvm +from tvm.target import Target +from .arch_base import TileDevice + + +def is_cpu_arch(arch: TileDevice) -> bool: + return isinstance(arch, CPU) + + +# For LLVM Backend, we do not provide the detailed information of the CPU +# As the LLVM backend do not required tuning, just maintain the consistency +class CPU(TileDevice): + def __init__(self, target: Target): + self.target = target + device = tvm.runtime.cpu(0) + if not device.exist: + raise RuntimeError("Cannot find cpu device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CPU" + + +__all__ = [ + 
"is_cpu_arch", + "CPU", +] diff --git a/tilelang/original/tilelang/carver/arch/cuda.py b/tilelang/original/tilelang/carver/arch/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..2b79b2832b223469f5f3aa8b6fc94d1448acd2c3 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/cuda.py @@ -0,0 +1,156 @@ +from __future__ import annotations +import tvm +from tvm.target import Target +from .arch_base import TileDevice +from .driver import cuda_driver + + +def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + +def is_cuda_arch(arch: TileDevice) -> bool: + return isinstance(arch, CUDA) + + +def is_volta_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 70) + conditions.append(arch.sm_version < 80) + return all(conditions) + + +def is_ampere_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80 and arch.sm_version < 89) + return all(conditions) + + +def is_ada_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 89) + return all(conditions) + + +def is_hopper_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 90) + return all(conditions) + + +def has_mma_support(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80) + return all(conditions) + + +volta_tensorcore_supported = [ + ("float16", "float32"), + ("float16", "float16"), +] +ampere_tensorcore_supported = [ + ("bfloat16", "float32"), + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("int4", "int32"), + ("int2", "int32"), + ("int1", "int32"), +] +ada_tensorcore_supported = [ + ("bfloat16", "float32"), + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("float8_e5m2", "float32"), + ("float8_e4m3", "float32"), +] +hopper_tensorcore_supported = ada_tensorcore_supported + + +# TODO(lei): we should consider the dtype of the input a and b +# instead of assuming both a and b share the same dtype. 
+# As the tensorcore may supports float8_e4m3 * float8_e5m2 +def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: + if is_volta_arch(arch): + return (in_dtype, accum_dtype) in volta_tensorcore_supported + elif is_ampere_arch(arch): + return (in_dtype, accum_dtype) in ampere_tensorcore_supported + elif is_ada_arch(arch): + return (in_dtype, accum_dtype) in ada_tensorcore_supported + elif is_hopper_arch(arch): + return (in_dtype, accum_dtype) in hopper_tensorcore_supported + else: + raise ValueError(f"Unsupported architecture: {arch}") + + +class TensorInstruction: + def __init__( + self, + name: str, + shape: list[int], + ): + self.name: str = name + # only hold the shape of M and N + self.shape: list[int] = shape + + +class CUDA(TileDevice): + def __init__(self, target: Target | str): + if isinstance(target, str): + target = tvm.target.Target(target) + self.target = target + self.sm_version = check_sm_version(self.target.arch) + device = tvm.runtime.cuda(0) + if not device.exist: + raise RuntimeError("Cannot find cuda device 0.") + self.name = cuda_driver.get_device_name() + self.device: tvm.runtime.Device = device + self.platform: str = "CUDA" + # TODO(lei): maybe static shared memory, can be improved in future + self.smem_cap = cuda_driver.get_shared_memory_per_block() + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + self.compute_capability = device.compute_version.replace(".", "") + self.reg_cap: int = 65536 + self.max_smem_usage: int = 2 * self.smem_cap + self.sm_partition: int = 4 + self.l2_cache_size_bytes: int = target.l2_cache_size_bytes + # the number of transaction size in bytes + self.transaction_size: list[int] = [32, 128] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + # TODO(lei): find some way to get the real bandwidth + # However, the ratio of bandwidth between different devices can + # be similar. The bandwidth can work for another devices as well. 
+ self.bandwidth: list[int] = [750, 12080] + # get the available tensor instructions during runtime to avoid + # the dependency of the tensor intrinsics registration + self.available_tensor_instructions: list[TensorInstruction] = None + + def get_avaliable_tensorintrin_shapes(self): + self.available_tensor_instructions = ( + TensorInstruction("mma", [16, 16]), + TensorInstruction("wmma", [16, 16]), + ) + return [t.shape for t in self.available_tensor_instructions] + + def __repr__(self): + return f"CUDA({self.target})" + + +__all__ = [ + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", + "CUDA", +] diff --git a/tilelang/original/tilelang/carver/arch/driver/__init__.py b/tilelang/original/tilelang/carver/arch/driver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b380f4763801c64b57823894be7de31658c56f1 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/driver/__init__.py @@ -0,0 +1,10 @@ +from .cuda_driver import ( + get_cuda_device_properties, # noqa: F401 + get_device_name, # noqa: F401 + get_shared_memory_per_block, # noqa: F401 + get_device_attribute, # noqa: F401 + get_max_dynamic_shared_size_bytes, # noqa: F401 + get_persisting_l2_cache_max_size, # noqa: F401 + get_num_sms, # noqa: F401 + get_registers_per_block, # noqa: F401 +) diff --git a/tilelang/original/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/original/tilelang/carver/arch/driver/cuda_driver.py new file mode 100644 index 0000000000000000000000000000000000000000..a631276635f6d53df368271a65ab6c84926c1f62 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/driver/cuda_driver.py @@ -0,0 +1,129 @@ +from __future__ import annotations +import ctypes +import sys + +try: + import torch.cuda._CudaDeviceProperties as _CudaDeviceProperties +except ImportError: + _CudaDeviceProperties = type("DummyCudaDeviceProperties", (), {}) + + +class cudaDeviceAttrNames: + r""" + refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd + """ + + cudaDevAttrMaxThreadsPerBlock: int = 1 + cudaDevAttrMaxRegistersPerBlock: int = 12 + cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 + cudaDevAttrMaxPersistingL2CacheSize: int = 108 + + +def get_cuda_device_properties(device_id: int = 0) -> _CudaDeviceProperties | None: + try: + import torch.cuda + + if not torch.cuda.is_available(): + return None + return torch.cuda.get_device_properties(torch.device(device_id)) + except ImportError: + return None + + +def get_device_name(device_id: int = 0) -> str | None: + prop = get_cuda_device_properties(device_id) + if prop: + return prop.name + + +def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None: + assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" + prop = get_cuda_device_properties(device_id) + if prop is None: + raise RuntimeError("Failed to get device properties.") + shared_mem = int(prop.shared_memory_per_block) + if format == "bytes": + return shared_mem + elif format == "kb": + return shared_mem // 1024 + elif format == "mb": + return shared_mem // (1024 * 1024) + else: + raise RuntimeError("Invalid format. 
Must be one of: bytes, kb, mb") + + +def get_device_attribute(attr: int, device_id: int = 0) -> int: + try: + if sys.platform == "win32": + libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll") + else: + libcudart = ctypes.cdll.LoadLibrary("libcudart.so") + + value = ctypes.c_int() + cudaDeviceGetAttribute = libcudart.cudaDeviceGetAttribute + cudaDeviceGetAttribute.argtypes = [ + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ctypes.c_int, + ] + cudaDeviceGetAttribute.restype = ctypes.c_int + + ret = cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id) + if ret != 0: + raise RuntimeError(f"cudaDeviceGetAttribute failed with error {ret}") + + return value.value + except Exception as e: + print(f"Error getting device attribute: {str(e)}") + return None + + +def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> int | None: + """ + Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. + """ + assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" + shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) + if format == "bytes": + return shared_mem + elif format == "kb": + return shared_mem // 1024 + elif format == "mb": + return shared_mem // (1024 * 1024) + else: + raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") + + +def get_persisting_l2_cache_max_size(device_id: int = 0) -> int: + prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id) + return prop + + +def get_num_sms(device_id: int = 0) -> int: + """ + Get the number of streaming multiprocessors (SMs) on the CUDA device. + + Args: + device_id (int, optional): The CUDA device ID. Defaults to 0. + + Returns: + int: The number of SMs on the device. + + Raises: + RuntimeError: If unable to get the device properties. + """ + prop = get_cuda_device_properties(device_id) + if prop is None: + raise RuntimeError("Failed to get device properties.") + return prop.multi_processor_count + + +def get_registers_per_block(device_id: int = 0) -> int: + """ + Get the maximum number of 32-bit registers available per block. + """ + prop = get_device_attribute( + cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock, + device_id, + ) + return prop diff --git a/tilelang/original/tilelang/carver/arch/metal.py b/tilelang/original/tilelang/carver/arch/metal.py new file mode 100644 index 0000000000000000000000000000000000000000..0b76849a7695175a113b14cde3db19605de37006 --- /dev/null +++ b/tilelang/original/tilelang/carver/arch/metal.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from tvm.target import Target +from .arch_base import TileDevice + + +def is_metal_arch(arch: TileDevice) -> bool: + return isinstance(arch, METAL) + + +class METAL(TileDevice): + def __init__(self, target: Target | str): + if isinstance(target, str): + target = Target(target) + self.target = target + + +__all__ = [ + "is_metal_arch", + "METAL", +] diff --git a/tilelang/original/tilelang/carver/common_schedules.py b/tilelang/original/tilelang/carver/common_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..4904b770dd6a5c83e219fea80061ed4f9fb5b46b --- /dev/null +++ b/tilelang/original/tilelang/carver/common_schedules.py @@ -0,0 +1,164 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm common_schedules.py in dlight. +"""Common schedule strategies for TIR.""" + +from typing import Callable + +from tvm import tir +from .utils import retrieve_func_from_module +from .analysis import BlockInfo + + +def get_block( + sch: tir.Schedule, + blocks: list[BlockInfo], + name: str, +): + """Get the target block from a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get target block. + name : str + The name of the target block. + + Returns + ------- + target_block : BlockRV + The target block. + """ + + target_block: tir.BlockRV = None + for block_info in blocks: + block = block_info.block_rv + if sch.get(block).name_hint == name: + target_block = block + return target_block + + +def get_output_blocks( + sch: tir.Schedule, + blocks: list[BlockInfo], +): + """Get the output blocks of a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get output blocks. + blocks : List[BlockInfo] + The blocks to be analyzed. + + Returns + ------- + output_blocks : List[BlockInfo] + The output blocks. + """ + + # collect arguments buffer + func = retrieve_func_from_module(sch.mod) + args = list(func.buffer_map.values()) + + output_blocks = [] + for block_info in blocks: + block = block_info.block_rv + for write in sch.get(block).writes: + if write.buffer in args: + output_blocks.append(block) + + return output_blocks + + +def try_inline( + sch: tir.Schedule, + blocks: list[BlockInfo], +) -> list[BlockInfo]: + """Try to inline as many blocks as possible, and return the remaining blocks. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + blocks : List[BlockInfo] + The blocks to be inlined. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. + """ + + def _trial(func: Callable): + for i, block in enumerate(blocks): + try: + func(block.block_rv) + except Exception: # pylint: disable=bare-except + continue + return i + return None + + while True: + i = _trial(sch.compute_inline) + if i is None: + i = _trial(sch.reverse_compute_inline) + if i is None: + break + blocks.pop(i) + return blocks + + +def try_inline_contiguous_spatial( + sch: tir.Schedule, + block_infos: list[BlockInfo], +) -> list[BlockInfo]: + """Try to inline contiguous spatial blocks in a schedule + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + block_infos : List[BlockInfo] + The blocks to be try. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. 
+ """ + + if block_infos is None: + return None + results = [] + spatial_blocks = [] + block: BlockInfo + for block in block_infos: + if block.is_injective(): + spatial_blocks.append(block) + elif spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + results.append(block) + spatial_blocks = [] + else: + results.append(block) + if spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + return results diff --git a/tilelang/original/tilelang/carver/matmul_analysis.py b/tilelang/original/tilelang/carver/matmul_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..6d27de8253c12cdebc3ba7f50ec66cb702148b63 --- /dev/null +++ b/tilelang/original/tilelang/carver/matmul_analysis.py @@ -0,0 +1,829 @@ +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" + +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from tvm import tir +from tvm.ir import Range +from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap +from tvm.tir.analysis import undefined_vars +from tvm.tir.schedule.schedule import BlockRV +from .analysis import ( + collect_block_iter_vars_used_in_access_region, + get_root_block, + get_reduction_blocks, +) +from tvm.target.target import Target +from tvm.tir.stmt_functor import pre_order_visit +from .arch import get_arch, is_tensorcore_supported_precision +import logging + +logger = logging.getLogger(__name__) + + +def collect_vars_from_expr(prim_expr): + vars = [] + + def callback(node): + if isinstance(node, Var): + vars.append(node) + return True + + pre_order_visit(prim_expr, callback) + + return vars + + +def _is_one(x: PrimExpr) -> bool: + return isinstance(x, tir.IntImm) and x.value == 1 + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def auto_inline_producers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + skip_blocks: list[tir.schedule.BlockRV] | None = None, +): + skip_blocks = skip_blocks or [] + while True: + inlined_cnt = 0 + producers = _collect_producers(sch, block) + for producer in producers: + if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): + continue + try: + sch.compute_inline(producer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + while True: + inlined_cnt = 0 + consumers = _collect_consumers(sch, block) + for consumer in consumers: + try: + sch.compute_inline(consumer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + for consumer in consumers: + try: + sch.reverse_compute_inline(consumer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumer_chain( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + auto_inline_consumers(sch, block) + remaining_consumers = sch.get_consumers(block) + + if len(remaining_consumers) != 0: + # Some blocks have failed to be inlined to the producer cache-write stage. 
+ # This could be due to another producer block that has not been scheduled. + for c in remaining_consumers: + for p in sch.get_producers(c): + if sch.get(p) != sch.get(block): + sch.compute_inline(p) + + # Try inlining into the cache-write stage again, this time it should succeed. + auto_inline_consumers(sch, block) + + +# used to match the similar region with dequantize op. +def find_first_similar_region(regions: list[BufferRegion], buffer: tir.Buffer): + for region in regions: + if len(region.buffer.shape) == len(buffer.shape): + return region + return None + + +# used to match the similar buffer with dequantize op. +def find_first_similar_buffer(regions: list[BufferRegion], buffer: tir.Buffer): + for region in regions: + if len(region.buffer.shape) == len(buffer.shape): + return region.buffer + return None + + +# find the block that required to be reindex and scope. +def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> BlockRV | None: + # block that most near to the arguments + block = main_block + buffer = buffer + + while True: + last_buffer = buffer + producers = sch.get_producers(block) + + if len(producers) == 0: + # do not have any producer means it is the first block + break + + for producer in producers: + for write in sch.get(producer).writes: + if write.buffer == buffer: + block = producer + buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer) + if buffer == last_buffer: + break + return block + + +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int: + """traverse to find the arg index from the buffer""" + producers = sch.get_producers(main_block) + + # a head buffer has no producer blocks + def find_args_index(sch: tir.Schedule, buffer: tir.Buffer): + for i, param in enumerate(sch.mod["main"].params): + if sch.mod["main"].buffer_map[param] == buffer: + return i + return None + + is_head_buffer = len(producers) == 0 + if is_head_buffer: + return find_args_index(sch, buffer) + for block in sch.get_producers(main_block): + if len(sch.get(block).reads) != 1 or len(sch.get(block).writes) != 1: + continue + for write in sch.get(block).writes: + if write.buffer == buffer: + return find_arg_idx_from_buffer_chain(sch, block, buffer) + + # if no buffer producer block found, it means the buffer is an input buffer + return find_args_index(sch, buffer) + + +class IterKind(Enum): + """Iter kinds for GEMM-liked programs. + We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K], + where `I, J, K` are fundamental axes for gemm and `S` represents all + other spatial axes (e.g. batches) + kIter_S: spatial axes + kIter_I: I axes + kIter_J: J axes + kIter_K: K axes + kIter_T: trivial axes (i.e. 
with extent 1) + """ + + kIter_S = 0 + kIter_I = 1 + kIter_J = 2 + kIter_K = 3 + kIter_T = 4 + + +@dataclass +class IterTrait: + kind: IterKind + extent: PrimExpr + + +def make_iter_fusion_index_map( + traits: list[IterTrait], + kind_order: list[IterKind], +) -> tir.IndexMap: + fused_iters: dict[IterKind, PrimExpr] = {} + input_iters: list[tir.Var] = [] + for i, trait in enumerate(traits): + v_i = tir.Var(f"i{i}", trait.extent.dtype) + input_iters.append(v_i) + if trait.kind == IterKind.kIter_T: + continue + if trait.kind not in kind_order: + raise ValueError(f"Unknown iter kind {trait.kind}") + if trait.kind in fused_iters: + fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i + else: + fused_iters[trait.kind] = v_i + + final_indices: list[tir.PrimExpr] = [fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order] + + return tir.IndexMap(input_iters, final_indices, None) + + +def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None: + """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + Returns + ------- + traits : Optional[Tuple[List[IterTrait]]] + The detected iter traits for axes in A, B and C. None if the block + does not match the pattern. + + """ + + if len(block.reads) != 2 or len(block.writes) != 1: + return None + + def get_access_axes(region: list[Range]) -> set[Var]: + axes: set[Var] = set() + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes = axes.union(set(undefined_vars(r.min))) + return axes + + try: + A_axes = get_access_axes(block.reads[0].region) + B_axes = get_access_axes(block.reads[1].region) + C_axes = get_access_axes(block.writes[0].region) + except ValueError: + return None + + traits: dict[Var, IterTrait] = {} + for iter_var in block.iter_vars: + var = iter_var.var + kind: IterKind + if _is_one(iter_var.dom.extent): + if iter_var.iter_type == tir.IterVar.CommReduce: + # for simplified case (e.g. 1x1 conv kernel) + kind = IterKind.kIter_K + else: + kind = IterKind.kIter_T + elif iter_var.iter_type == iter_var.DataPar: + if var in A_axes and var in B_axes and var in C_axes: + kind = IterKind.kIter_S + elif var in A_axes and var in C_axes: + kind = IterKind.kIter_I + elif var in B_axes and var in C_axes: + kind = IterKind.kIter_J + else: + return None + elif iter_var.iter_type == tir.IterVar.CommReduce: + if var in A_axes and var in B_axes and var not in C_axes: + kind = IterKind.kIter_K + else: + return None + else: + return None + traits[var] = IterTrait(kind, iter_var.dom.extent) + + # A Gemm-kernel requires have I, J and K axes + gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K} + if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: + return None + + A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] + B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] + C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] + block_traits = [traits[i.var] for i in block.iter_vars] + return A_traits, B_traits, C_traits, block_traits + + +def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] 
| None: + """Get index maps for the block + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + layout : List[str] + the target layout index map to be used. + 'n' for [i, k] layout + 't' for [k, j] layout + 'a' for auto inference based on whether the last axis is reduction. + + Returns + ------- + index_maps : Optional[Tuple[tir.IndexMap]] + The index maps for the block, or None if the block is not a gemm-liked kernel + """ + if layout is None: + layout = ["n", "t", "n"] + traits = detect_iter_traits(block) + if traits is None: + return None + A_traits, B_traits, C_traits, block_traits = traits + + def get_ordered_axes(region: list[Range]) -> set[Var]: + axes: list[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars) + + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + return any(is_common_reduce(v) for v in vars) + + def check_last_trait(region: list[Range]): + axes = get_ordered_axes(region) + return has_common_reduce(axes[-1]) + + def infer_layout(layout: str, region: list[Range], kind: str = "A"): + """ + Infer the layout based on the region and the kind of buffer + kind: "A", "B", "C" + """ + primary_iter, secondary_iter, reduction_iter = { + "A": (IterKind.kIter_I, IterKind.kIter_K, IterKind.kIter_K), + "B": (IterKind.kIter_K, IterKind.kIter_J, IterKind.kIter_K), + "C": (IterKind.kIter_I, IterKind.kIter_J, None), + }[kind] + + spatial_iter = { + "A": IterKind.kIter_I, + "B": IterKind.kIter_J, + "C": None, + }[kind] + + if layout == "n": + return [IterKind.kIter_S, primary_iter, secondary_iter] + elif layout == "t": + return [IterKind.kIter_S, secondary_iter, primary_iter] + elif layout == "a": + # auto inference layout + # for buffer with reduction axis, we put it as the last axis + # otherwise, we put it as the first axis + if kind == "C": + return [IterKind.kIter_S, primary_iter, secondary_iter] + else: + return ( + [IterKind.kIter_S, spatial_iter, reduction_iter] + if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter] + ) + else: + raise ValueError(f"Unknown layout {layout}") + + A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) + B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) + C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) + + matmul_index_map = make_iter_fusion_index_map( + block_traits, + [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K], + ) + + return ( + matmul_index_map, + A_index_map, + B_index_map, + C_index_map, + ) + + +def get_in_out_dtypes(block: tir.Block) -> tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. 
+ """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +def get_dequantize_block(sch, blocks) -> BlockRV | None: + # check at least two input and one output + # at lease one input has uint dtype, and the output dtype is float + def is_dequantize(block: BlockRV) -> bool: + block_stmt = sch.get(block) + if len(block_stmt.reads) < 2: + return False + has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) + if not has_uint_input: + return False + return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype)) + + dequantize_blocks = [block for block in blocks if is_dequantize(block)] + return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None + + +def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + if iter_types != {IterVar.DataPar}: + return False, False + if not isinstance(block_stmt.body, tir.BufferStore): + return False, False + if not isinstance(block_stmt.body.value, tir.BufferLoad): + return False, False + + def get_access_vars(region: list[Range]) -> list[Var]: + axes: list[Var] = [] + for r in region: + if not _is_one(r.extent): + return None + axes.extend(undefined_vars(r.min)) + # remove trivial axis + trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) + axes = [axis for axis in axes if axis not in trivial_vars] + # remove duplicate axis + axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] + return axes + + lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] + rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] + is_identity = list(lhs_access_vars) == list(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars) + return is_identity, is_transpose + + +def is_identity_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[0] + + +def is_transpose_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[1] + + +def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]): + result_blocks = [] + for block in blocks: + if not is_transpose_block(sch.get(block)): + result_blocks.append(block) + continue + try: + sch.compute_inline(block) + except Exception: + try: + sch.reverse_compute_inline(block) + except Exception: + result_blocks.append(block) + return result_blocks + + +def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None: + if layout is None: + layout = ["n", "t", "n"] + block_stmt = sch.get(main_block) + + # let layout be 'a' to auto inference the layout + index_maps = get_index_map(block_stmt, layout=layout) + if index_maps is None: + logger.debug("Cannot find the appropriate index map for tensorcore") + return None + + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # `skip_simplify` to avoid the bug in the 1x1 conv + block = sch.reindex(main_block, ("read", 0), skip_simplify=True) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1), skip_simplify=True) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0), skip_simplify=True) + 
sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True) + return sch + + +def get_tensorized_func_and_tags( + func: tir.PrimFunc, + target: Target, + layout: list[str] | None = None, + skip_normalize: bool = False, + allow_gemv: bool = False, +) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]: + """ + transform function to matmul if necessary (e.g. transform conv2d with im2col) + """ + if layout is None: + layout = ["a", "a", "a"] + # step1. detect whether the function can utilize tensorcore + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + if not reduction_blocks or len(reduction_blocks) != 1: + return func, None + + def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + conditions = [] + conditions.append(len(block_stmt.reads) == 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0) + return all(conditions) + + # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) + def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool | dict: + tags: dict[str, list[int] | int] = {} + block_stmt = sch.get(block) + + # Nvidia Only Support Tensor Core for + # devices greater than 70. + if check_sm_version(target.arch) < 70: + return False + # analysis tensorcore axis + # todo(lei): maybe we can remove this in the future + (write_buffer_region,) = block_stmt.writes + out_axis = len(write_buffer_region.buffer.shape) + tags["tensorcore_config"] = [out_axis - 2, out_axis - 1] + + # analysis pipeline stage + # todo(lei): maybe we can integrate this into policy in the future + tags["pipeline_stage"] = 1 + if target.kind.name == "cuda" and check_sm_version(target.arch) in {80, 90}: + # enable pipeline stage only for sm_80 devices + tags["pipeline_stage"] = 2 + + # analysis async copy + # todo(lei): maybe we can integrate this into policy in the future + tags["use_async_copy"] = False + if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) in {80, 90}: + # async copy only works in software pipeline. + tags["use_async_copy"] = True + + # analysis intrin information + def get_ordered_axes(region: list[Range]) -> set[Var]: + axes: list[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars) + + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + return any(is_common_reduce(v) for v in vars) + + def check_last_trait(region: list[Range]): + axes = get_ordered_axes(region) + return has_common_reduce(axes[-1]) + + intrin_info: dict = {} + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + intrin_info["in_dtype"] = in_dtype + intrin_info["out_dtype"] = out_dtype + + if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32": + # INT32 Accum TensorCore only supports SM Version > 32. 
+ return False + + # if the last dimension is reduce axis, the B is transposed + intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) + if func.attrs is not None and "input_transform_kind" in func.attrs: + intrin_info["input_transform_kind"] = func.attrs["input_transform_kind"] + if func.attrs is not None and "weight_transform_kind" in func.attrs: + intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"] + tags["intrin_info"] = intrin_info + # Analysis Block Reduction Optimization + # Currently, we only support block reduction depth 2 for small M + # When the func is a dequantize like ops, we should consider the M + require_block_reduce = False + # And we only support float16 for now + if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]: + for arg in func.params: + inp_shape = func.buffer_map[arg].shape + M = inp_shape[0] + if isinstance(M, tir.IntImm) and M <= 128: + require_block_reduce = True + break + if require_block_reduce and check_sm_version(target.arch) == 80: + tags["block_reduction_depth"] = 2 + return tags + + (main_block,) = reduction_blocks + if _can_be_tensorized(sch, main_block) is None: + return func, None + + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): + logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore") + return func, None + + # reindex and transform functions + # Normalize tensor functions to C[S, I, J] += A[S, I, K] * B[S, J, K] + # or C[S, I, J] += A[S, I, K] * B[S, K, J] + # skip normalize when we want to detect tags only. + if not skip_normalize: + sch = normalize_to_matmul(sch, main_block, layout) + if sch is None: + return func, None + + block_stmt = sch.get(main_block) + + # 16 for 16 bits tensor core while 32 for 8bits tensorcore. + minimal_tensorize_spatial_threshold = 16 + minimal_tensorize_reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 + # the batch dimension is not taken into consideration. 
+ for item_var in block_stmt.iter_vars[1:]: + extent = item_var.dom.extent + iter_type = item_var.iter_type + + if iter_type is IterVar.DataPar: + minimal_tensorize_threshold = minimal_tensorize_spatial_threshold + elif iter_type is IterVar.CommReduce: + minimal_tensorize_threshold = minimal_tensorize_reduce_threshold + else: + raise ValueError(f"Unknown IterVar type {iter_type}") + + if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold: + return func, None + tags = analysis_tensorcore_tags(sch, main_block, target) + return sch.mod["main"], tags + + return func, None + + +def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): + from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, + ) + + assert dtype in [ + "bfloat16", + "float16", + "int8", + "float8_e4m3", + "float8_e5m2", + ], "Only support bfloat16, float16, int8, float8_e4m3, float8_e5m2" + # TODO(lei): actually should analyze based on bits instead of dtype + if dtype in ["bfloat16", "float16"]: + ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout + ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + elif dtype in ["int8", "float8_e4m3", "float8_e5m2"]: + # int8 mma only support 32x16 to 16x32 layout + if matrix_name == "A" and trans is False: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a + elif matrix_name == "B" and trans is True: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b + else: + raise ValueError("Unknown matrix name ", matrix_name) + + # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out + def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout(thread_id, local_id) + + def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout_trans(thread_id, local_id) + + def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_layout(thread_id, local_id) + + if dtype in ["bfloat16", "float16"]: + ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16 + else: + ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 + + ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype) + # TODO(lei): index_dtype should be analyzed from the schedule + row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32] + inversed_index_map = ldmatrix_index_map.inverse([row, col]) + return ldmatrix_index_map, inversed_index_map + + +# This function is used to get the index map for the stage3 of the +# Ladder weight propagation, which can be used to avoid the ldmatrix +# Instructions. 
+def get_ladder_stage3_map(dtype="float16", index_dtype="int32"): + def shared_32x8_to_mma_32x8_layout(i, j): + thread_id = (i % 8) * 4 + (j // 2) + local_id = (i // 8) * 2 + (j % 2) + return thread_id, local_id + + def shared_32x16_to_mma_32x16_layout(i, j): + thread_id = (i % 8) * 4 + (j // 4) + local_id = (i // 8) * 4 + (j % 4) + return thread_id, local_id + + assert dtype in [ + "bfloat16", + "float16", + "int8", + "float8_e4m3", + "float8_e5m2", + ], "Only support float16, int8, float8_e4m3, float8_e5m2" + if dtype in ["bfloat16", "float16"]: + stage3_layout = shared_32x8_to_mma_32x8_layout + elif dtype in ["int8", "float8_e4m3", "float8_e5m2"]: + stage3_layout = shared_32x16_to_mma_32x16_layout + else: + raise ValueError("Unknown dtype ", dtype) + + # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out + def ladder_stage3_permutation_16x16_32x8_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + new_thread_id, new_local_id = stage3_layout(thread_id, local_id) + new_kernel_i = (new_thread_id * 8 + new_local_id) // 16 + new_kernel_j = (new_thread_id * 8 + new_local_id) % 16 + return new_kernel_i, new_kernel_j + + def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + new_thread_id, new_local_id = stage3_layout(thread_id, local_id) + new_kernel_i = (new_thread_id * 16 + new_local_id) // 32 + new_kernel_j = (new_thread_id * 16 + new_local_id) % 32 + return new_kernel_i, new_kernel_j + + if dtype in ["bfloat16", "float16"]: + stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16 + else: + stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32 + + stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype) + # TODO(lei): index_dtype should be analyzed from the schedule + row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32] + inversed_index_map = stage3_index_map.inverse([row, col]) + return stage3_index_map, inversed_index_map + + +def layout_propagate_chain( + sch: tir.Schedule, + start_block: BlockRV, + start_buffer: tir.Buffer, + end_block: BlockRV, + index_map: IndexMap, +): + # some layout transformation may only apply to the last n dimensions + # propagate the layout transformation to the chain of blocks + block = start_block + buffer = start_buffer + index_map = index_map + while True: + last_buffer = buffer + producers = sch.get_producers(block) + if len(producers) == 0: + break + for producer in producers: + if len(sch.get(producer).writes) != 1: + return index_map + if sch.get(producer) == sch.get(end_block): + return index_map + (write,) = sch.get(producer).writes + + read = find_first_similar_region(sch.get(producer).reads, last_buffer) + if write.buffer == buffer: + block = producer + buffer = read.buffer + write_indices = [r.min for r in write.region] + read_indices = [r.min for r in read.region] + # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout + tmp_index_map = IndexMap(write_indices, read_indices, None) + tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] + + # if dequantize like ops are used, the scaling factor should be considered + # to be applied to the final indices + scaling_factor = 1 + for i, j in zip(write.buffer.shape, read.buffer.shape): + scaling_factor *= i // j + final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices))) + final_indices[-1] 
= final_indices[-1] // scaling_factor + index_map = IndexMap( + write_indices, + final_indices, + None, + ) + if buffer == last_buffer: + break + return index_map diff --git a/tilelang/original/tilelang/carver/roller/__init__.py b/tilelang/original/tilelang/carver/roller/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06931ffe7c7c8df7761dc49def73245bd9670e7b --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/__init__.py @@ -0,0 +1,5 @@ +from .node import PrimFuncNode, OutputNode, Edge # noqa: F401 +from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 +from .hint import Hint # noqa: F401 +from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 +from ..arch import TileDevice, CUDA # noqa: F401 diff --git a/tilelang/original/tilelang/carver/roller/bestfit.py b/tilelang/original/tilelang/carver/roller/bestfit.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7817429d8dd2fa79db1a33988086758556f19f --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/bestfit.py @@ -0,0 +1,62 @@ +"""Benefit For Carver Schedule""" + + +class Block: + def __init__(self, start, end, is_free): + self.start = start + self.end = end + self.is_free = is_free + + def size(self) -> int: + return self.end - self.start + + def merge(self, other): + assert self.is_free == other.is_free + self.start = min(self.start, other.start) + self.end = max(self.end, other.end) + + def __repr__(self) -> str: + return f"" + + +class BestFit: + def __init__(self, align=32): + self.limit = 0 + self.list = [] + self.align = align + + def malloc(self, size) -> Block: + size = (size + self.align - 1) // self.align * self.align + found = None + for block in self.list: + if block.is_free and block.size() >= size and (not found or found.size() > block.size()): + found = block + if found: + found.is_free = False + remain = found.size() - size + if remain != 0: + found.end -= remain + self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True)) + return found + elif len(self.list) > 0 and self.list[-1].is_free: + add = size - self.list[-1].size() + self.list[-1].end += add + self.limit = self.list[-1].end + self.list[-1].is_free = False + return self.list[-1] + else: + block = Block(self.limit, self.limit + size, False) + self.list.append(block) + self.limit += size + return block + + def free(self, block: Block) -> None: + assert not block.is_free + idx = self.list.index(block) + self.list[idx] = Block(block.start, block.end, True) + if idx + 1 < len(self.list) and self.list[idx + 1].is_free: + self.list[idx].merge(self.list[idx + 1]) + self.list.pop(idx + 1) + if idx - 1 >= 0 and self.list[idx - 1].is_free: + self.list[idx].merge(self.list[idx - 1]) + self.list.pop(idx - 1) diff --git a/tilelang/original/tilelang/carver/roller/hint.py b/tilelang/original/tilelang/carver/roller/hint.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd1fb40652f855fe22ea1271da6af550549200e --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/hint.py @@ -0,0 +1,258 @@ +"""Hint definition for schedule""" + +from tvm import DataType +from . 
import PrimFuncNode +import numpy as np +from .rasterization import * + + +class TensorCoreExtraConfig: + """ + This class is used to store extra information for tensorcore + """ + + def __init__( + self, + AS_shape: tuple[int], + BS_shape: tuple[int], + AF_shape: tuple[int], + BF_shape: tuple[int], + tc_axis: tuple[int], + ) -> None: + self.AS_shape: tuple[int] = AS_shape + self.BS_shape: tuple[int] = BS_shape + self.AF_shape: tuple[int] = AF_shape + self.BF_shape: tuple[int] = BF_shape + self.tc_axis: tuple[int] = tc_axis + + +class Stride: + """ + Manages stride information for a given axis of a tensor. + """ + + def __init__(self, stride: int = 1, ax: int = -1) -> None: + # which axis to put stride on + self._ax: int = int(ax) + # the stride size of the axis + self._stride: int = int(stride) + + @property + def ax(self) -> int: + return self._ax + + @property + def stride(self) -> int: + return self._stride + + def compute_strides_from_shape(self, shape: list[int]) -> list[int]: + ndim = len(shape) + strides = [1 for _ in shape] + for i in range(ndim - 2, -1, -1): + if i == self.ax: + strides[i] = self.stride + else: + strides[i] = int(strides[i + 1] * shape[i + 1]) + return strides + + def compute_elements_from_shape(self, shape: list[int]) -> int: + original_shape = np.prod(shape) + if not self.is_valid(): + strided_elem = original_shape + else: + assert self.ax < len(shape) + strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride + assert strided_elem >= original_shape + return int(strided_elem) + + def is_valid(self) -> bool: + return self.ax >= 0 + + def __repr__(self) -> str: + return f"" + + +class TileDict: + """ + Manages tiling information and configurations for computational tasks. + """ + + def __init__(self, output_tile) -> None: + self.output_tile = output_tile + # schedule config + self.tile_map = {} + self.rstep_map = {} + self.cached_tensors_map = {} + self.output_strides_map = {} + self.tensor_strides_map = {} + + # analysis + self.traffic = -1 + self.smem_cost = -1 + self.block_per_SM = -1 + self.num_wave = -1 + self.grid_size = -1 + self.valid = True + + def get_tile(self, func) -> list[int]: + return self.tile_map[func] + + def get_rstep(self, node) -> dict[str, int]: + return self.rstep_map[node] + + def __hash__(self) -> int: + return hash(tuple(self.output_tile)) + + +class IntrinInfo: + """ + The information of tensorcore intrinsic related information + """ + + def __init__( + self, + in_dtype: str, + out_dtype: str, + trans_b: bool, + input_transform_kind: int = 0, + weight_transform_kind: int = 0, + ) -> None: + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.trans_a = False + self.trans_b = trans_b + self.input_transform_kind = input_transform_kind + self.weight_transform_kind = weight_transform_kind + + def __repr__(self) -> str: + return f"" + + def is_input_8bit(self) -> bool: + return DataType(self.in_dtype).bits == 8 + + @property + def smooth_a(self) -> bool: + return self.input_transform_kind >= 2 + + @property + def smooth_b(self) -> bool: + return self.weight_transform_kind >= 2 + + @property + def inter_transform_a(self) -> bool: + return self.input_transform_kind >= 1 + + @property + def inter_transform_b(self) -> bool: + return self.weight_transform_kind >= 1 + + +class Hint: + """ + Central configuration class for managing various parameters of computational tasks. + """ + + def __init__(self) -> None: + self.arch = None + self.use_tc = None # todo(lei): this should be renamed. 
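The `IntrinInfo` transform-kind flags defined above are easiest to read with concrete values. A minimal sketch follows; the dtype combination and kind values are illustrative assumptions (the exact meaning of each kind level is defined by the surrounding transform pipeline), not taken from a real schedule:

```python
# Illustrative values only: an int8 MMA config whose weight operand has been
# pre-transformed one level further than the activation operand.
info = IntrinInfo(in_dtype="int8", out_dtype="int32", trans_b=True,
                  input_transform_kind=1, weight_transform_kind=2)
assert info.is_input_8bit()                           # int8 -> 8-bit path
assert info.inter_transform_a and not info.smooth_a   # kind 1: first level only
assert info.inter_transform_b and info.smooth_b       # kind 2: both levels
```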
+ + # Special axes tiling info + self.block = [] + self.thread = [] + # Special axes for MFMA + self.warp = [] + # Reduce axes tiling info + self.rstep = [] + self.reduce_thread = [] + self.rasterization_plan = NoRasterization() + self.cached_tensors = [] + self.output_strides = {} + self.schedule_stages = None + # Config for block reduction + self.block_reduction_depth = None # type: int + + # TL Specific + # Split-K factor for SM waste optimization + self.split_k_factor: int = 1 + + # Experimental + self._raxis_order = [] + self._step = [] + self.vectorize: dict[str, int] = {} + self.pipeline_stage = 1 + self.use_async = False + self.opt_shapes: dict[str, int] = {} + self.intrin_info = IntrinInfo("float16", "float16", True) + self.shared_scope: str = "shared" + self.pass_context: dict = {} + + def to_dict(self) -> dict: + dic = {} + dic["block"] = self.block + if self.use_tc: + dic["warp"] = self.warp + else: + dic["thread"] = self.thread + dic["rstep"] = self.rstep + if np.prod(self.reduce_thread) > 1: + dic["reduce_thread"] = self.reduce_thread + if self.use_tc: + dic["use_tc"] = self.use_tc + if self.output_strides: + dic["strides"] = {} + for k, stride in self.output_strides.items(): + if stride.is_valid(): + dic["strides"][k] = stride + if len(dic["strides"]) == 0: + del dic["strides"] + if np.prod(self._step) > 1: + dic["step"] = self._step + if self._raxis_order != []: + dic["raxis_order"] = self._raxis_order + if self.vectorize != {}: + dic["vectorize"] = self.vectorize + if self.pipeline_stage != 1: + dic["pipeline_stage"] = self.pipeline_stage + if self.block_reduction_depth is not None: + dic["block_reduction_depth"] = self.block_reduction_depth + return dic + + @classmethod + def from_dict(cls, dic: dict) -> "Hint": + hint = cls() + for k, v in dic.items(): + setattr(hint, k, v) + return hint + + def tensorcore_legalization(self): + # only keep the last 2 axes for tensorcore + self.warp = self.warp[-2:] + self.block = self.block[-2:] + return self + + @property + def raxis_order(self) -> list[int]: + if self._raxis_order != []: + return self._raxis_order + return list(range(len(self.rstep))) + + @property + def step(self) -> list[int]: + if self._step != []: + return self._step + return [1 for _ in self.block] + + def __repr__(self) -> str: + return str(self.to_dict()) + + def complete_config(self, node: PrimFuncNode): + # analysis pass context, for int8 mma, we should merge static shared memory + merge_static_smem = False + # int32 and float32 accum may take too much shared memory + if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]: + merge_static_smem = True + # Always merge dynamic shared memory + if self.shared_scope == "shared.dyn": + merge_static_smem = True + self.pass_context = {"tir.merge_static_smem": merge_static_smem} + return self diff --git a/tilelang/original/tilelang/carver/roller/node.py b/tilelang/original/tilelang/carver/roller/node.py new file mode 100644 index 0000000000000000000000000000000000000000..3122c7b078ad982ffbc8bc99d62580ac6617474d --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/node.py @@ -0,0 +1,588 @@ +"""PrimFunc Wrapper and Block information Analaysis""" + +from __future__ import annotations + +import tvm +from tvm import tir +from tvm.tir import IterVar, PrimFunc +from typing import Any +from tvm.tir.schedule.schedule import BlockRV +import numpy as np +import functools +from ..analysis import BlockInfo, get_reduction_blocks +from .. import analysis +from .. 
import normalize_prim_func +from .shape_inference import get_analyzer_by_tir +from dataclasses import dataclass + + +def pre_order_traverse(block_analyzer, blocks, func): + visited = set() + + def _traverse(block): + if block in visited: + return + visited.add(block) + for dep_blocks in block_analyzer.get_consumer_blocks(block): + _traverse(dep_blocks) + func(block) + + for block in blocks: + _traverse(block) + + +class BlockAnalyzer: + def __init__(self, sch) -> None: + self.sch: tir.Schedule = sch + self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch) + + def get_block_name(self, block: BlockRV) -> str: + return self.sch.get(block).name_hint + + def get_block_info(self, block: BlockRV) -> BlockInfo: + for block_info in self.block_infos: + if self.get_block_name(block) == block_info.name: + return block_info + return None + + def get_spatial_axis(self, block: BlockRV) -> list[IterVar]: + block_info = self.get_block_info(block) + axis = [] + for iter in block_info.iters: + if iter.kind == "S": + axis.append(iter) + return axis + + def get_reduce_axis(self, block: BlockRV) -> list[IterVar]: + block_info = self.get_block_info(block) + raxis = [] + for iter in block_info.iters: + if iter.kind == "R": + raxis.append(iter) + return raxis + + def get_input_buffers(self, block: BlockRV) -> list[tir.Buffer]: + buffers = [] + for read in self.sch.get(block).reads: + buffers.append(read.buffer) + return buffers + + def get_output_buffers(self, block: BlockRV) -> list[tir.Buffer]: + buffers = [] + for write in self.sch.get(block).writes: + buffers.append(write.buffer) + return buffers + + def get_buffers(self, block: BlockRV) -> list[tir.Buffer]: + return self.get_input_buffers(block) + self.get_output_buffers(block) + + def get_producer_blocks(self, block: BlockRV) -> list[BlockRV]: + return self.sch.get_producers(block) + + def get_consumer_blocks(self, block: BlockRV) -> list[BlockRV]: + return self.sch.get_consumers(block) + + +@dataclass +class Edge: + src_node: Node + dst_node: Node + src_id: int + dst_id: int + + +class Node: + def __init__(self, tags: dict | None = None, name: str = "Node") -> None: + self.name = name + if tags is None: + tags = {} + self._out_edges = [] + self._in_edges = [] + self._shapes = [] + self._dtypes = [] + self._tag: dict = {} + self.update_tags(tags) + + def update_tags(self, tags: dict) -> None: + for tag in tags: + self.add_tag(tag, tags[tag]) + + def set_tag(self, k: str, v: Any = True) -> None: + self.add_tag(k, v) + + def add_tag(self, k: str, v: Any = True) -> None: + self._tag[k] = v + + def get_tag(self, k: str) -> Any: + if k not in self._tag: + return None + return self._tag[k] + + def is_placeholder(self): + return False + + def is_output(self): + return False + + @property + def inputs(self) -> list[Edge]: + return self._in_edges + + @property + def outputs(self) -> list[Edge]: + return self._out_edges + + def set_inputs(self, i: int, edge: Edge): + assert i < len(self._in_edges) + self._in_edges[i] = edge + + def set_outputs(self, i: int, edge: Edge): + assert i < len(self._out_edges) + self._out_edges[i] = edge + + def get_dtype(self, id=0) -> tvm.DataType: + return self._dtypes[id] + + def set_dtype(self, dtype: tvm.DataType, id=0) -> None: + assert isinstance(dtype, tvm.DataType), type(dtype) + if dtype == tvm.DataType("bool"): + dtype = tvm.DataType("int8") + if len(self._dtypes) <= id: + self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) + elif self._dtypes[id] is not None: + assert self._dtypes[id] == dtype, 
(self._dtypes, dtype) + self._dtypes[id] = dtype + + def get_shape(self, id: int = 0) -> list[int]: + return self._shapes[id] + + def set_shape(self, shape: list[int], id=0, overwrite=False) -> None: + if len(self._shapes) <= id: + self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) + # elif self._shapes[id] is not None and not overwrite: + # assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape))) + self._shapes[id] = list(map(int, shape)) + + def num_outputs(self) -> int: + if len(self.outputs) == 0: + return 0 + return max([e.src_id for e in self.outputs]) + 1 + + def get_ir(self) -> str: + raise NotImplementedError() + + def __repr__(self) -> str: + return "" + + +class PlaceHolderNode(Node): + def __init__(self, name=""): + super().__init__(name="PlaceHolder_" + name) + + def is_placeholder(self): + return True + + def get_ir(self) -> str: + return "placeholder" + + +class PrimFuncNode(Node): + def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None: + super().__init__(tags, name=name) + self.prim_func = self._specialize_func(prim_func) + self.sch: tir.Schedule = tir.Schedule(self.prim_func) + self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch) + self.schedule_stages: list[BlockRV] = [] + self.blocks: list[BlockRV] = [] + self.output_blocks: list[BlockRV] = None + self.reduction_block: BlockRV = None + self.raxis = [] + self.input_buffers = [] + self.output_buffers = [] + self.buffers = [] + self.args = [] + self._analysis_funcinfo() + self._assign_placeholder_node() + self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks) + + # set input shapes and dtypes + for edge, buffer in zip(self.inputs, self.input_buffers): + edge.src_node.set_shape(buffer.shape, edge.src_id) + edge.src_node.set_dtype(tvm.DataType(buffer.dtype), edge.src_id) + for output_id, buffer in enumerate(self.output_buffers): + self.set_shape(buffer.shape, output_id) + self.set_dtype(tvm.DataType(buffer.dtype), output_id) + + def _assign_placeholder_node(self): + inputs: list[Node] = [] + for buffer in self.input_buffers: + inputs.append(PlaceHolderNode(buffer.name)) + + for dst_id, n in enumerate(inputs): + if isinstance(n, Node): + n = (n, 0) + assert len(n) == 2 + src_node, src_id = n[0], n[1] + edge = Edge(src_node, self, src_id, dst_id) + self._in_edges.append(edge) + src_node._out_edges.append(edge) + + def _specialize_func(self, func: PrimFunc): + # Specialize the function to make it more friendly for analysis. 
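Before the specialization step, a toy illustration of the edge bookkeeping performed by `_assign_placeholder_node` above may help; the node names here are made up for the example:

```python
# Toy graph: two placeholder inputs feeding one consumer node, wired the same
# way _assign_placeholder_node wires PrimFuncNode inputs.
a, b = PlaceHolderNode("A"), PlaceHolderNode("B")
consumer = Node(name="consumer")
for dst_id, src in enumerate([a, b]):
    edge = Edge(src, consumer, 0, dst_id)
    consumer._in_edges.append(edge)
    src._out_edges.append(edge)

assert [e.src_node.name for e in consumer.inputs] == ["PlaceHolder_A", "PlaceHolder_B"]
assert a.outputs[0].dst_node is consumer
```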
+ # set attrs + for k, v in func.attrs.items(): + self.set_tag(k, v) + if self.get_tag("is_speclized"): + return func + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = analysis.find_var_from_func(func, name) + if var is not None: + func = func.specialize({var: shape.astype(var.dtype)}) + return func + + def _analysis_funcinfo(self): + root_block = analysis.get_root_block(self.sch) + blocks = self.sch.get_child_blocks(root_block) + self.blocks = blocks + + self.output_blocks = self.sch.get_output_blocks(root_block) + reduction_blocks = get_reduction_blocks(self.sch, blocks) + if reduction_blocks is None: + self.reduction_block = None + self.schedule_stages.append(*self.output_blocks) + else: + # analysis on the last reduction block + self.reduction_block = reduction_blocks[-1] + # set raxis + reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in reduce_block_info.iters: + if iter.kind == "R": + self.raxis.append(iter) + self.schedule_stages.append(self.reduction_block) + + # collect output buffers + for output_block in self.output_blocks: + for write in self.sch.get(output_block).writes: + if write not in self.output_buffers: + self.output_buffers.append(write.buffer) + + for param in self.prim_func.params: + if param not in self.prim_func.buffer_map: + # in case of dynamic symbolic may in params + continue + buffer = self.prim_func.buffer_map[param] + if buffer not in self.output_buffers: + self.input_buffers.append(buffer) + + self.args = self.input_buffers + self.output_buffers + self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()] + + # set dtype + self.set_dtype(tvm.DataType(self.output_buffers[0].dtype)) + + def get_opt_shape(self, name) -> int: + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes is None: + return None + return opt_shapes[name] + + def extent_wrapper(self, value) -> int: + if isinstance(value, tvm.tir.Var): + return self.get_opt_shape(value.name) + elif isinstance(value, tvm.tir.IntImm): + return int(value) + else: + return value + + @functools.lru_cache + def get_space_dim(self) -> list[int]: + dim_size = [] + if self.reduction_block: + block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in block_info.iters: + if iter.kind == "S": + if isinstance(iter.dom.extent, tvm.tir.IntImm): + dim_size.append(int(iter.dom.extent)) + else: + assert isinstance(iter.dom.extent, tvm.tir.Var) + dim_size.append(self.get_opt_shape(iter.dom.extent.name)) + else: + # assume outer stage has the same shape + loops = self.sch.get_loops(self.schedule_stages[0]) + for loop in loops: + dim_size.append(int(self.sch.get(loop).extent)) + return [int(x) for x in dim_size] + + def set_dtype(self, dtype: tvm.DataType, id=0) -> None: + assert isinstance(dtype, tvm.DataType), type(dtype) + if dtype == tvm.DataType("bool"): + dtype = tvm.DataType("int8") + if len(self._dtypes) <= id: + self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) + elif self._dtypes[id] is not None: + assert self._dtypes[id] == dtype, (self._dtypes, dtype) + self._dtypes[id] = dtype + + def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: + return tvm.DataType(buffer.dtype) + + def propagate(self, tile, rstep: dict | None = None, targets=None): + if rstep is None: + rstep = {} + shape = { + self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile] + for block in self.schedule_stages + } + return 
self.ana.infer(shape, rstep, targets) + + def propagate_inputs(self, tile, rstep: dict | None = None) -> list[list[int]]: + if rstep is None: + rstep = {} + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[:read_idx_offset]] + shapes, intermediate_bind = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[:read_idx_offset]): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))] + results.append(trimmed_shape) + return results + + # Propagate inputs only on reduction block + def propagate_inputs_on_reduction(self, tile, rstep: dict | None = None) -> list[list[int]]: + if rstep is None: + rstep = {} + reduction_block = self.reduction_block + args = self.block_analyzer.get_input_buffers(reduction_block) + targets = [t.name for t in args] + shapes, intermediate_bind = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(args): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + propagate_shape = shapes[arg.name] + buffer_shape = args[i].shape + if len(buffer_shape) > len(propagate_shape): + buffer_shape = buffer_shape[-len(propagate_shape) :] + trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))] + results.append(trimmed_shape) + return results + + def propagate_outputs(self, tile, rstep: dict | None = None) -> list[list[int]]: + if rstep is None: + rstep = {} + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[read_idx_offset:]] + shapes, _ = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[read_idx_offset:]): + # should not exceed original shape + trimmed_shape = list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + results.append(trimmed_shape) + return results + + def propagate_reduction_inputs(self, shape, rstep: dict | None = None) -> dict[str, list[int]]: + if rstep is None: + rstep = {} + if self.reduction_block is None: + return {} + targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)] + results, _ = self.propagate(shape, rstep, targets) + return results + + def get_reduce_inputs_dtype(self): + if self.reduction_block is None: + return {} + return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)} + + @functools.lru_cache + def infer_tensorcore_axis(self) -> tuple[int]: + # axis is fixed for one expression, so only inference and cached + assert self.get_tag("tensorcore_config") + + C_ax_m, C_ax_n = self.get_tag("tensorcore_config") + wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok + + output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape + valid_region = [] + for region in output_buffer_shape: + if region.value == 1: + continue + valid_region.append(region) + + num_nvalid_regions = len(output_buffer_shape) - len(valid_region) + self.set_tag("num_nvalid_regions", num_nvalid_regions) + + def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): + spatial_dim = self.get_space_dim() + assert len(valid_region) == len(spatial_dim), f" {valid_region} mismatch with {spatial_dim}" + cl_shapes = [1] * len(spatial_dim) + cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m + 
cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n + return cl_shapes + + CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions) + self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]]) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_m = A_deps.index(wmma_m) + B_ax_n = B_deps.index(wmma_n) + + CL_shape = [1] * len(self.get_space_dim()) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k) + B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k) + tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) + return tc_axis + + def footprint(self, shape, rstep, stride_map: dict | None = None) -> int: + if stride_map is None: + stride_map = {} + result = 0 + shapes, _ = self.propagate(shape, rstep) + + def is_broadcast_pattern(buffer, output_buffer): + return ( + buffer in self.args + and len(shapes[output_buffer.name]) > len(shapes[buffer.name]) + and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]) + ) + + def is_after_reduce_stage(block): + if not self.reduction_block: + return False + reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None) + if reduce_dependent_blocks is None: + reduce_dependent_blocks = set() + pre_order_traverse( + self.block_analyzer, + [self.reduction_block], + lambda block: reduce_dependent_blocks.add(block), + ) + self.reduce_dependent_blocks = reduce_dependent_blocks + return block not in reduce_dependent_blocks + + # compute cached stages + cached_tensor = [] + for block in self.blocks: + output_buffer = self.block_analyzer.get_output_buffers(block)[0] + for buffer in self.block_analyzer.get_input_buffers(block): + cache = buffer.name not in cached_tensor and ( + is_broadcast_pattern(buffer, output_buffer) or self.block_analyzer.get_block_info(block).is_reduction() + ) + if not cache: + continue + cached_tensor.append(buffer.name) + if is_after_reduce_stage(block): + continue # cache after reduce op can often reuse buffer in reduce stage + + if buffer.name in stride_map: + num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name]) + else: + num_elem = np.prod(shapes[buffer.name]) + buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) + buffer_len = (buffer_len + 31) // 32 * 32 + result += buffer_len + return result, cached_tensor + + def get_input_buffers(self) -> list[tir.Buffer]: + return self.block_analyzer.input_buffers + + +class OutputNode(Node): + def __init__(self, node, id=0): + super().__init__(name="OutputNode") + # connect node and output node + assert isinstance(node, PrimFuncNode), "OutputNode should connect to PrimFuncNode" + + # initialize edge and connect + src_node, src_id = node, id + edge = Edge(src_node, self, src_id, 0) + self._in_edges.append(edge) + src_node._out_edges.append(edge) + + self.set_shape(node.get_shape(id)) + self.set_dtype(node.get_dtype(id)) + + def is_output(self): + return True + + def get_ir(self) -> str: + return "output" + + +def topo_order(list_of_nodes) -> list[Node]: + input_ready_count = {node: len(node.inputs) for node in list_of_nodes} + ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes)) + output_list = [] + while len(ready) > 0: + node = ready.pop(0) + output_list.append(node) + for edge in node.outputs: + dst_node = edge.dst_node + if dst_node not in 
input_ready_count: + input_ready_count[dst_node] = len(dst_node.inputs) + list_of_nodes.append(dst_node) + input_ready_count[dst_node] -= 1 + assert input_ready_count[dst_node] >= 0 + if input_ready_count[dst_node] == 0: + ready.append(dst_node) + assert len(list_of_nodes) == len(output_list) + return output_list + + +def find_topo_sort_priority(output_node_list) -> list[Node]: + import sys + + sys.setrecursionlimit(10000) + + def topo_sort_get_layer(node, topo_layer): + if node in topo_layer: + return + topo_layer[node] = 0 + for edge in node.inputs: + topo_sort_get_layer(edge.src_node, topo_layer) + topo_layer[node] = max(topo_layer[node], topo_layer[edge.src_node] + 1) + + topo_layer = {} + for node in output_node_list: + topo_sort_get_layer(node, topo_layer) + + def topo_sort_dfs(node, visited, topo_order): + if node in visited: + return + visited.add(node) + ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True) + for n in ordered_input_nodes: + topo_sort_dfs(n, visited, topo_order) + topo_order.append(node) + + visited = set() + topo_order = [] + for node in output_node_list: + topo_sort_dfs(node, visited, topo_order) + return topo_order + + +def find_topo_sort(output_node_list) -> list[Node]: + def topo_sort_dfs(node, visited, topo_order): + if node in visited: + return + visited.add(node) + for edge in node.inputs: + topo_sort_dfs(edge.src_node, visited, topo_order) + topo_order.append(node) + + visited = set() + topo_order = [] + for node in output_node_list: + topo_sort_dfs(node, visited, topo_order) + return topo_order diff --git a/tilelang/original/tilelang/carver/roller/policy/__init__.py b/tilelang/original/tilelang/carver/roller/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b39a27bdb5957e4f0ea1b233b19b88db7095b1bc --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/policy/__init__.py @@ -0,0 +1,2 @@ +from .default import DefaultPolicy # noqa: F401 +from .tensorcore import TensorCorePolicy # noqa: F401 diff --git a/tilelang/original/tilelang/carver/roller/policy/common.py b/tilelang/original/tilelang/carver/roller/policy/common.py new file mode 100644 index 0000000000000000000000000000000000000000..fb33eefdb7565528e7707811bb9d07dc5c7f9c7a --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/policy/common.py @@ -0,0 +1,52 @@ +import numpy as np + + +def get_all_factors(n: int) -> list[int]: + # Calculate the square root of n and round it up to the nearest integer + n0 = int(np.ceil(np.sqrt(n))) + + # Find all divisors of n that are less than n0 + val = np.where(n % np.arange(1, n0) == 0)[0] + 1 + + # If n is a perfect square, add the square root to the list of factors + mid = np.array([], dtype=int) if n0 * n0 != n else [n0] + + # Combine the factors and their corresponding larger pair factors + return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])] + + +def factorize(n: int) -> list[int]: + i = 2 # Start with the smallest prime number + result = [] + + # Iterate through numbers to find factors + while n > 1: + if n % i == 0: # If i is a factor of n + n //= i # Divide n by i and keep the integer part + result.append(i) + else: + i += 1 # Try the next number + return result + + +def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int: + # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension + if subtensor[-1] != tensor[-1] or len(subtensor) == 1: + return subtensor[-1] + else: + # Recursively 
calculate the coalesced factor for the remaining dimensions + return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1]) + + +def coalesced_tensor_shape(subtensor: list[int], tensor: list[int], transaction_size: int) -> int: + # Calculate the total number of elements in the subtensor + bytes = int(np.prod(subtensor)) + + if bytes == 0: + return 0 + + # Calculate the coalesced factor for the subtensor + factor = int(coalesced_factor(subtensor, tensor)) + + # Compute the shape of the coalesced tensor + return transaction_size * bytes / min(transaction_size, factor) diff --git a/tilelang/original/tilelang/carver/roller/policy/default.py b/tilelang/original/tilelang/carver/roller/policy/default.py new file mode 100644 index 0000000000000000000000000000000000000000..d09216e1ceb4ce7eb02a4643d18fac4990a01a34 --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/policy/default.py @@ -0,0 +1,816 @@ +"""Policy for cuda core schedule""" + +from __future__ import annotations +import functools +import math +from queue import PriorityQueue +from collections.abc import Iterable + +import numpy as np +import tvm + +from ...arch import TileDevice +from ..bestfit import BestFit +from ..hint import Hint, Stride, TileDict +from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors +from ..node import PrimFuncNode, OutputNode, find_topo_sort +from ..rasterization import NoRasterization + + +class DefaultPolicy: + """ + Default Policy for fastdlight, a heuristic plan that tries to + minimize memory traffic and maximize parallelism.for BitBLAS Schedule. + """ + + func: tvm.tir.PrimFunc + nodes: list[PrimFuncNode] = [] + arch: TileDevice + tags: dict + + def __init__(self, arch: TileDevice, tags: dict | None = None) -> None: + if tags is None: + tags = {} + + self.arch = arch + self.tags = tags + self.rasterization = NoRasterization() + + @classmethod + def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"): + return cls(arch, tags)._init_with_prim_func(func, name) + + @classmethod + def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None): + return cls(arch, tags)._init_with_output_nodes(nodes) + + def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy: + if func is not None and isinstance(func, tvm.tir.PrimFunc): + self.func = func + self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name) + else: + raise NotImplementedError("Only support PrimFunc for now") + output_nodes = [OutputNode(self.prim_func_node)] + self._init_with_output_nodes(output_nodes) + return self + + def _init_with_output_nodes(self, output_nodes: list[OutputNode]): + self.ordered_nodes = list(filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes))) + for node in self.ordered_nodes: + node.update_tags(self.tags) + + self.output_nodes = [] + for node in self.ordered_nodes: + is_topo_output = True + for edge in node.outputs: + if not edge.dst_node.is_output(): + is_topo_output = False + if is_topo_output: + self.output_nodes.append(node) + return self + + def emit_config(self, topk: int) -> list[Hint]: + base_tile = self.get_base_tile() + if base_tile is None: + return [] + + rstep_map = {node: self._assign_reduce_step(node) for node in self.ordered_nodes} + smem_tile_condidates = self.dfs_smem_tile(base_tile, rstep_map) + results = [] + for td in smem_tile_condidates: + if not 
self.check_tile_shape_isvalid(td): + continue + + self._expand_reduce_axis(td) + for codegen_dicts in self.assign_block_size(td): + if isinstance(codegen_dicts, dict) and len(codegen_dicts) == 1: + results.append(list(codegen_dicts.values())[0]) + else: + results.append(codegen_dicts) + if len(results) >= topk: + break + if len(results) >= topk: + break + return results + + def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: + _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()] + steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)] + for i in range(len(steps)): + added = list( + filter( + lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], + [2, 4, 8, 16, 32], + ) + ) + steps[i].extend(added) + steps[i] = sorted(steps[i]) + visited_tiles = {} + queue = PriorityQueue() + + def prio(td: TileDict): + return (td.traffic + 1) * td.num_wave + + def add_to_queue(tile): + if tuple(tile) in visited_tiles: + return + td = self.compute_tile_dict(tile, rstep_map) + visited_tiles[tuple(tile)] = td + if td.valid: + queue.put([prio(td), tile]) + + add_to_queue(init_tile) + while not (queue.empty() or len(visited_tiles) > 2000): + _, tile = queue.get() + dim_ids = [step.index(t) for step, t in zip(steps, tile)] + for i in reversed(range(len(dim_ids))): + if dim_ids[i] + 1 < len(steps[i]): + new_tile = tile.copy() + new_tile[i] = steps[i][dim_ids[i] + 1] + add_to_queue(new_tile) + + visited_tiles = filter(lambda td: td.valid, visited_tiles.values()) + sorted_tiles = sorted(visited_tiles, key=lambda td: prio(td)) + return sorted_tiles + + def get_base_tile(self): + """ + Gets the minimum tile configuration that satisfies no redundancy in computation. + + Returns + ------- + List[int] + The base tile configuration, which is a list of 1s equal in length to the space dimensions + of the primary function node. + """ + if len(set([len(node.get_space_dim()) for node in self.output_nodes])) > 1: + # If output dim sizes are not same, don't know how to handle them + return None + + out_node = self.output_nodes[0] + shape = out_node.get_space_dim() + base_tile = [1 for _ in shape] + wpi = self.compute_workload_per_item(base_tile) + for dim, n in enumerate(shape): + factors = [n] + for factor in factors: + if factor == base_tile[dim]: + continue + tile = base_tile.copy() + tile[dim] = factor + new_wpi = self.compute_workload_per_item(tile) + if new_wpi < wpi: + wpi, base_tile = new_wpi, tile + else: + break + + return base_tile + + # handles multiple output cases + def _get_output_tile_map(self, tile): + """ + Handles multiple output cases by mapping output nodes to their respective tile configurations. + + Parameters + ---------- + tile : List[int] + The tile configuration. + + Returns + ------- + Dict + A dictionary mapping the primary function node to its corresponding tile configuration + based on the output nodes' space dimensions. 
+ """ + tile_map = {} + for node in self.output_nodes: + tile_map[node] = [tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))] + return tile_map + + def compute_workload_per_item(self, output_tile) -> float: + op_tile_map = self._get_output_tile_map(output_tile) + compute = 0 + num_item = int(np.prod(output_tile)) + for node in reversed(self.ordered_nodes): + tile = op_tile_map[node] + dep = node.propagate_inputs(tile) + compute += int(np.prod(tile)) + for i, edge in enumerate(node.inputs): + op_tile_map[edge.src_node] = dep[i] + return float(compute / num_item) + + def score_block_size(self, n): + """ + Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition. + + Parameters + ---------- + n : int + The block size to score. + + Returns + ------- + Tuple[float, float] + A tuple containing two scores representing efficiency and fit, respectively. + """ + num_wrap = (n + self.arch.warp_size - 1) // self.arch.warp_size + r1 = max(num_wrap / self.arch.sm_partition, self.arch.sm_partition / num_wrap) + r2 = (num_wrap * self.arch.warp_size - n) / n + return (r1, r2) + + def get_block_size(self, n): + """ + Determines the optimal block size for a given constraint, based on scoring various factors. + + Parameters + ---------- + n : int + The constraint size. + + Returns + ------- + int + The optimal block size chosen from the factors of n, constrained by a maximum of 1024 and + scored by the `score_block_size` method. + """ + factors = get_all_factors(n) + factors = list(filter(lambda x: x <= 1024, factors)) + factor_ordered = sorted(factors, key=self.score_block_size) + return factor_ordered[0] + + def get_node_reduce_step_candidates(self, node: PrimFuncNode): + """ + Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2. + + Parameters + ---------- + node : PrimFuncNode + The node for which to calculate reduction step candidates. It contains reduction axes (raxis) + with their domains (dom.extent). + + Returns + ------- + Dict[str, List[int]] + A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, + this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 + as step candidates; for others, it uses all factors of the domain. + """ + + results = {} + for k_iter in node.raxis: + all_factors = get_all_factors(int(k_iter.dom.extent)) + if len(all_factors) == 2 and int(k_iter.dom.extent) > 64: + all_factors = [1] + while all_factors[-1] * 2 < int(k_iter.dom.extent): + all_factors.append(all_factors[-1] * 2) + results[k_iter.var.name] = all_factors + return results + + def _assign_reduce_step(self, node: PrimFuncNode): + """ + Assigns an optimal reduction step for the given PrimFuncNode. + + Parameters + ---------- + node : PrimFuncNode + The node for which the reduction step is to be assigned. + + Returns + ------- + Dict + A dictionary mapping reduction axis variable names to their optimal reduction steps. 
+ """ + if node.reduction_block is None: + return {} + + raxis = node.raxis + tile = [1] * len(node.get_space_dim()) + all_steps = self.get_node_reduce_step_candidates(node) + + def sim(a: int, b: int): + return (2 * a * b) / (a * a + b * b) + + def _score(rstep_id): + rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id} + score = 0 + shape = node.propagate_inputs(tile, rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8) + score += sim( + int(coalesced_factor(shape[i], input_buffer.shape)), + read_transaction_elements, + ) + return score + + def _enlarge(rstep_id): + candidates = [] + candidates.append((rstep_id, _score(rstep_id))) + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + best = max(candidates, key=lambda x: x[1]) + return best + + # enlarge rstep to ensure read is coaleased + cur_rstep_id = {ax.var.name: 0 for ax in raxis} + cur_score = _score(cur_rstep_id) + while True: + if cur_score == 0: + break + new_rstep, new_score = _enlarge(cur_rstep_id) + if new_score <= cur_score: + break + else: + cur_rstep_id, cur_score = new_rstep, new_score + rstep = {k: all_steps[k][cur_rstep_id[k]] for k in cur_rstep_id} + return rstep + + def _expand_reduce_axis(self, td: TileDict): + """ + Expands the reduction axis in the TileDict based on shared memory limits. + + Parameters + ---------- + td : TileDict + The TileDict object to be optimized. + + Returns + ------- + None + This function modifies the TileDict in place. + """ + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + + def _score(rstep_id): + rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis} + score = 0 + shape = node.propagate_inputs(td.get_tile(node), rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = self._compute_shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis} + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map[node]) + rstep_map[node] = rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + + def _compute_memory_traffic(self, output_tile): + """ + Computes 
the memory traffic for a given output tile configuration. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + + Returns + ------- + Tuple[int, Dict] + The total memory traffic and a map of operation tiles. + """ + op_tile_map = self._get_output_tile_map(output_tile) + traffic = 0 + for node in reversed(self.ordered_nodes): + tile = op_tile_map[node] + input_shapes = node.propagate_inputs(tile) + output_shapes = node.propagate_outputs(tile) + for i, edge in enumerate(node.inputs): + op_tile_map[edge.src_node] = input_shapes[i] + if edge.src_node.is_placeholder(): + nbytes = (edge.src_node.get_dtype().bits + 7) // 8 + read_transaction_elements = self.arch.transaction_size[1] // nbytes + traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes + for edge in node.outputs: + if edge.dst_node.is_output(): + nbytes = (edge.src_node.get_dtype().bits + 7) // 8 + write_transaction_elements = self.arch.transaction_size[0] // nbytes + traffic += ( + coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements) * nbytes + ) + + return traffic, op_tile_map + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + """ + Infers the shared memory usage of a node given a TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + node : PrimFuncNode + The node for which to infer the shared memory usage. + + Returns + ------- + int + The estimated amount of shared memory used by the node. + """ + return node.footprint(td.get_tile(node), td.get_rstep(node), td.tensor_strides_map[node]) + + def _compute_shared_memory_usage(self, td: TileDict): + """ + Computes the stride map for a given node and TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + The output strides and tensor strides. + """ + self._compute_stride_map(td) + allocator = BestFit() + block_map = {} + processed = set() + cached_tensors_map = {} + + def can_free(node, out_id): + return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs) + + for node in self.ordered_nodes: + node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node) + block = allocator.malloc(node_internal_bytes) + allocator.free(block) + # free inputs + processed.add(node) + for edge in node.inputs: + if not edge.src_node.is_placeholder() and can_free(edge.src_node, edge.src_id): + allocator.free(block_map.pop((edge.src_node, edge.src_id))) + # alloc outputs + for edge in node.outputs: + if not edge.dst_node.is_output() and (node, edge.src_id) not in block_map: + dtype_bytes = (node.get_dtype(edge.src_id).bits + 7) // 8 + stride = td.output_strides_map[node][len(node.inputs) + edge.src_id] + output_elem = stride.compute_elements_from_shape(td.get_tile(node)) + block_map[(node, edge.src_id)] = allocator.malloc(output_elem * dtype_bytes) + + assert len(block_map) == 0 + return allocator.limit, cached_tensors_map + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + """ + Computes the stride map for a given node based on the TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. 
+ td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + A tuple of dictionaries containing the output strides and tensor strides. + """ + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} + tensor_strides = {} + return output_strides, tensor_strides + + def _compute_stride_map(self, td: TileDict): + """ + Computes the stride map for all nodes in a TileDict. + + Parameters + ---------- + td : TileDict + The TileDict object for which to compute the stride maps. + + Returns + ------- + None + This function updates the TileDict object in place with the computed stride maps. + """ + output_strides_map = {} + tensor_strides_map = {} + for node in self.ordered_nodes: + output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td) + td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map + + def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict: + """ + Computes and returns a TileDict object for a given output tile configuration and reduction step map. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + rstep_map : Dict + The reduction step map. + + Returns + ------- + TileDict + A TileDict object containing the computed tile configuration, memory traffic, shared memory cost, + grid size, and other related parameters. + """ + td = TileDict(output_tile) + td.rstep_map = rstep_map + td.traffic, td.tile_map = self._compute_memory_traffic(output_tile) + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + if td.smem_cost > self.arch.smem_cap: + td.valid = False + return td + output_shape = self.output_nodes[0].get_space_dim() + td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) + # estimated reg usage + reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes])) + if reg_usage > self.arch.reg_cap: + td.valid = False + return td + td.block_per_SM = min( + self.arch.max_smem_usage // max(td.smem_cost, 1), + self.arch.reg_cap // max(reg_usage, 1), + self.arch.sm_partition, + ) + td.num_wave = int(np.ceil(td.grid_size / int(td.block_per_SM * self.arch.compute_max_core))) + return td + + def check_tile_shape_isvalid(self, td: TileDict) -> bool: + """ + Checks if the tile shapes in the TileDict are valid for the nodes in this context. + + Parameters: + - td (TileDict): The TileDict object containing tile shapes and other configurations. + + Returns: + - bool: True if all tile shapes are valid, False otherwise. + """ + for node in self.ordered_nodes: + if np.prod(td.get_tile(node)) == 0: + return False + node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())]) + if node_grid_size != td.grid_size: + return False + if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile): + for i, tile_extent in enumerate(td.output_tile): + if node.reduce_op.axis[i].dom.extent % tile_extent: + return False + + return True + + def recommend_block_size(self, td: TileDict) -> list[int]: + """ + Recommends optimal block sizes based on the TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + List[int] + A list of recommended block sizes sorted based on their score. 
+ """ + node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] + max_block_size = functools.reduce(math.gcd, node_space_sizes) + + if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes): + node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes] + total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] + max_possible_size = functools.reduce(math.gcd, total_sizes) + possible_block_sizes = list( + filter( + lambda x: x % max_block_size == 0 and x <= 1024, + get_all_factors(max_possible_size), + ) + ) + possible_block_sizes = list( + filter( # either be a factor of space or cover fully cover the space + lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), + possible_block_sizes, + ) + ) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + else: + possible_block_sizes = get_all_factors(max_block_size) + possible_block_sizes = list(filter(lambda x: x <= 1024, possible_block_sizes)) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + + def assign_block_size(self, td: TileDict, topk=1): + """ + Assigns block sizes to the TileDict based on the recommended block sizes. + + Parameters + ---------- + td : TileDict + The TileDict object to assign block sizes to. + topk : int, optional + The number of top block sizes to consider. + + Yields + ------- + Dict + The block size assignment for the primary function node. + """ + block_size_ordered = self.recommend_block_size(td) + for block_size in block_size_ordered: + result = {} + failed = False + for node in self.ordered_nodes: + result[node] = self._assign_block_size(node, td, block_size) + if result[node] is None: + failed = True + break + if failed: + continue + else: + yield result + topk -= 1 + if topk == 0: + break + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Assigns a block size to a given PrimFuncNode based on the TileDict configuration and the specified block size. + + Parameters + ---------- + node : PrimFuncNode + The node to assign the block size to. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size to be assigned. + + Returns + ------- + Hint + A Hint object containing the assigned block size and other related settings. 
+ """ + tile, rsteps = td.get_tile(node), td.get_rstep(node) + factors = factorize(block_size) + cur_threads = [1 for _ in tile] + reduce_thread = {k: 1 for k in rsteps} + ndim = len(tile) + + def _score(node, thread): # small is better + score = 0 + block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] + shape = node.propagate_inputs(block_tile) + for i, _ in enumerate(node.input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + for buffer in node.output_buffers: + score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0] + return score + + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if cur_threads[i] >= tile[i]: + continue + if (tile[i] % (cur_threads[i] * factor)) != 0: + continue + cur_threads[i] *= factor + score_map[i] = (_score(node, cur_threads), i) + cur_threads[i] //= factor + if len(score_map) > 0: + # assign to space axis + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + cur_threads[dim_order[0]] *= factor + else: + # assign to reduce axis + target_ax = None + for ax, ax_len in reversed(list(rsteps.items())): + if ax_len % (reduce_thread[ax] * factor) == 0: + target_ax = ax + break + assert target_ax + reduce_thread[target_ax] *= factor + + codegen_dict = Hint() + codegen_dict.block = tile + codegen_dict.thread = cur_threads + codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis] + codegen_dict.reduce_thread = [reduce_thread[ax.var.name] for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + if node.get_dtype().bits == 16: # set step=2 for 16bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 2 == 0: + codegen_dict._step[i] = 2 + break + elif node.get_dtype().bits == 8: # set step=4 for 8bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 4 == 0: + codegen_dict._step[i] = 4 + break + # Plan vectorize + codegen_dict.vectorize = self._plan_vectorize(node, td, block_size) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = node.get_tag("opt_shapes") + return codegen_dict + + def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Plans vectorization for a given PrimFuncNode based on the TileDict configuration and block size. + + Parameters + ---------- + node : PrimFuncNode + The node for which to plan vectorization. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size used for vectorization planning. + + Returns + ------- + Dict + A dictionary mapping tensors to their vectorization size. 
+ """ + + def is_cont(shape, vec): + if len(shape) == 0: + return vec == 1 + last = shape[-1] + if last == 1: + return is_cont(shape[0:-1], vec // last) + else: + return last % vec == 0 + + def is_shape_aligned(shape, factor): + return int(np.prod(shape)) % factor == 0 + + def is_type_allowed(dtype, vec): + return dtype.bits * vec <= 128 + + vectorize_sizes = [16, 8, 4, 2] + dtypes = node.get_reduce_inputs_dtype() + shapes = node.propagate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) + vectorize_result = {} + for tensor, shape in shapes.items(): + for v in vectorize_sizes: + if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v): + vectorize_result[tensor] = v + break + return vectorize_result + + def plan_rasterization(self, td: TileDict): # pylint: disable=unused-argument + """ + Plans the rasterization for the given TileDict. This function is not implemented yet. + + Parameters + ---------- + td : TileDict + The TileDict object to plan rasterization for. + + Raises + ------- + RasterRationPlan + This function is not implemented yet. + """ + return NoRasterization() diff --git a/tilelang/original/tilelang/carver/roller/policy/tensorcore.py b/tilelang/original/tilelang/carver/roller/policy/tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..86c79ea732a374847fa5ca79e88c4508596a282e --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/policy/tensorcore.py @@ -0,0 +1,337 @@ +"""Policy for tensorcore schedule""" + +from __future__ import annotations +import tvm +import numpy as np +import logging +from ..hint import Hint, Stride, TileDict, IntrinInfo +from ..node import PrimFuncNode +from .common import coalesced_factor, factorize, get_all_factors +from .default import DefaultPolicy +from ..rasterization import NoRasterization, Rasterization2DColumn + +logger = logging.getLogger(__name__) + + +class TensorCorePolicy(DefaultPolicy): + # this is the trick for wmma. + # However, for int8 mma, the wmma_k should be 32. + wmma_k: int = 16 + pipeline_stage: int = 1 + use_async_copy: bool = False + block_reduction_depth: int | None = None + + def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str | None = None): + super()._init_with_prim_func(func, name) + self._legalize_info() + return self + + def _legalize_info(self): + pipleline_stage = self.prim_func_node.get_tag("pipeline_stage") + if pipleline_stage: + self.pipeline_stage = pipleline_stage + else: + if self.arch.compute_capability in {"sm_80", "sm_90", "sm_90a"}: + self.pipeline_stage = 2 + else: + self.pipeline_stage = 1 + use_async_copy = self.prim_func_node.get_tag("use_async_copy") + if use_async_copy: + self.use_async_copy = use_async_copy + else: + if self.arch.compute_capability in {"sm_80", "sm_90", "sm_90a"}: + self.use_async_copy = True + else: + self.use_async_copy = False + # TODO: block reduction depth is not used for now. + # As there still exists some performance issues for block reduction. + block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") + if block_reduction_depth: + self.block_reduction_depth = block_reduction_depth + + def _compute_tc_strides( + self, + node: PrimFuncNode, + tile: list[int], + rstep: dict[str, int] | None = None, + ) -> tuple[Stride, Stride, Stride]: + if rstep is None: + rstep = {} + # strides was used for shared memory padding. which is necessary for avoiding + # shared memory load bank conflict when we do not applying tensorcore layout. 
+ shapes = node.propagate_reduction_inputs(tile, rstep) + AS_shape, BS_shape = shapes.values() + CS_shape = tile + A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis() + + # applying strides + # TODO(leiwang1999): offset should be dynamically set. we can use tag -> enable_offset to control this option.. + offset = 8 + A_high_ax = min(A_ax_m, A_ax_k) + B_high_ax = min(B_ax_n, B_ax_k) + C_high_ax = min(C_ax_m, C_ax_n) + A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax) + B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax) + C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax) + return A_stride, B_stride, C_stride + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + value, cached_tensors = super().infer_node_smem_usage(td, node) + value *= self.pipeline_stage + return value, cached_tensors + + def _assign_reduce_step(self, node): + if not node.get_tag("tensorcore_config"): + return super()._assign_reduce_step(node) + # get reduce input size + target_transaction = self.arch.transaction_size[0] * 2 + # 512 bytes // type bits + reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0]) + basic = (target_transaction * 8) // reduce_input_dtype.bits + + result = {} + for iter_info in node.raxis: + iter_name = iter_info.var.name + iter_dom = iter_info.dom.extent + if iter_dom % 16 > 0: + result[iter_name] = 16 if iter_dom < basic else basic # for the case of padding + elif iter_dom % basic == 0: + result[iter_name] = basic + else: + return super()._assign_reduce_step(node) + return result + + def _expand_reduce_axis(self, td: TileDict): + # For tensorcore program, if we got a small tilesize, we should consider expand the reduce axis + # to improve compute efficiency. 
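For orientation, the reduce step chosen by `_assign_reduce_step` above works out as follows under assumed numbers (a CUDA-like device whose first transaction size is 32 bytes and fp16 inputs; these are assumptions for the sketch, not a measured configuration):

```python
# Worked example with assumed parameters.
transaction_bytes = 32                      # assumed self.arch.transaction_size[0]
target_transaction = transaction_bytes * 2  # 64 bytes, as in _assign_reduce_step
basic = (target_transaction * 8) // 16      # fp16 -> 32 elements per reduce step
assert basic == 32                          # a K of 4096 would therefore get rstep 32
```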
+ def _check_small_tile(td: TileDict): + minimal_threadhold = 32 + for node in self.ordered_nodes: + tile = td.get_tile(node) + if any([t <= minimal_threadhold for t in tile]): + return True + return False + + if _check_small_tile(td): + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + if any([v == [] for v in all_steps.values()]): + return rstep + + def _shared_memory_usage(td: TileDict): + return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) + + def _score(rstep_id): + rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis} + score = 0 + shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, input_buffer in enumerate(input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = _shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis} + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map[node]) + rstep_map[node] = rstep + + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + + if self.block_reduction_depth is not None: + + def _expand_with_tags(rstep): + new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()} + return new_rstep + + rstep_map = td.rstep_map.copy() + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _expand_with_tags(rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map + + return + + def get_node_reduce_step_candidates(self, node): + if not node.get_tag("tensorcore_config"): + return super().get_node_reduce_step_candidates(node) + else: + # must be a a multiple of wmma_k + return {k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis} + + def check_tile_shape_isvalid(self, td: TileDict): + for node in self.ordered_nodes: + if node.get_tag("tensorcore_config"): + ax_m, ax_n = node.get_tag("tensorcore_config") + block_m, block_n = ( + td.tile_map[node][ax_m], + td.tile_map[node][ax_n], + ) + # check the tile size is valid + wmma_invalid = [block_m < wmma_m or block_n < wmma_n for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()] + if all(wmma_invalid): + return False + if any([y % x 
for x, y in zip(td.tile_map[node], node.get_space_dim())]): + return False + return super().check_tile_shape_isvalid(td) + + def _can_implement_layout(self, node: PrimFuncNode, td: TileDict): + # Not implemented yet + # This function is used to check whether we can implement swizzling + # layout under this tile config + return False + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + if not node.get_tag("tensorcore_config"): + return super().compute_node_stride_map(node, td) + use_layout = self._can_implement_layout(node, td) + + AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node)) + A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) + tensor_strides = {} + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} + tensor_strides = {} + # when connected to shared input, should use full stride without rstep + for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): + if use_layout: + continue + _ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name + # TODO(lei): should dig further for shared memory connection case. + + return output_strides, tensor_strides + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + if not node.get_tag("tensorcore_config"): + return super()._assign_block_size(node, td, block_size) + ax_m, ax_n = node.get_tag("tensorcore_config") + if block_size % self.arch.warp_size != 0: + return None + tile, rsteps = td.get_tile(node), td.get_rstep(node) + warps = block_size // self.arch.warp_size + ndim = len(tile) + + wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1] + wmma_tile = [1 for _ in range(ndim)] + wmma_tile[ax_m] = wmma[0] + wmma_tile[ax_n] = wmma[1] + + space = [tile[i] // wmma_tile[i] for i in range(ndim)] + if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]: + # allow pad, otherwise, we can not get a valid tile shape + return None + + factors = factorize(np.prod(space) // warps) + + def _score(node, warp_tile): # small is better + score = 0 + shape = node.propagate_inputs_on_reduction(warp_tile) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, _ in enumerate(input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + return score + + warp_tile = wmma_tile.copy() + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if tile[i] % (warp_tile[i] * factor) != 0: + continue + warp_tile[i] *= factor + score_map[i] = (_score(node, warp_tile), i) + warp_tile[i] //= factor + if len(score_map) == 0: + return None + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + warp_tile[dim_order[0]] *= factor + + codegen_dict = Hint() + codegen_dict.block = tile + codegen_dict.warp = warp_tile + codegen_dict.use_tc = True + codegen_dict.pipeline_stage = self.pipeline_stage + codegen_dict.block_reduction_depth = self.block_reduction_depth + codegen_dict.use_async = self.use_async_copy + codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + intrin_info = node.get_tag("intrin_info") + if intrin_info: + codegen_dict.intrin_info = IntrinInfo(**intrin_info) + if intrin_info["out_dtype"] in ["float32"]: + codegen_dict.shared_scope = "shared.dyn" + # smem capacity + # TODO: This is a dummy mul which avoid reusing some shared memory. 
+ # Should be removed in the future. + if td.smem_cost > (self.arch.smem_cap): + # Tile Dict: {td.output_tile} Shared memory exceeds the static capacity + # use dynamic shared memory. + codegen_dict.shared_scope = "shared.dyn" + + codegen_dict.shared_scope = "shared.dyn" + + codegen_dict.complete_config(node) + codegen_dict.vectorize = self._plan_vectorize(node, td, block_size) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = node.get_tag("opt_shapes") + codegen_dict.tensorcore_legalization() + return codegen_dict + + def plan_rasterization(self, td: TileDict): + conditions = [] + # only support single node for now + conditions.append(len(self.ordered_nodes) > 1) + # only on Ampere+ arch + conditions.append(self.arch.compute_capability < "80") + + def _check_memory_size(): + overall_gmem_size_in_bytes: int = 0 + for node in self.ordered_nodes: + for buffer in node.input_buffers: + overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8 + return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes + + conditions.append(_check_memory_size()) + if any(conditions): + return NoRasterization() + # otherwise, simply provide a block rasterization factor + raster_factor = int(self.arch.compute_max_core**0.5) + + return Rasterization2DColumn(raster_factor) diff --git a/tilelang/original/tilelang/carver/roller/rasterization.py b/tilelang/original/tilelang/carver/roller/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..ec565a1c7c29cbc1a7c11191fbb8d248a5368d32 --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/rasterization.py @@ -0,0 +1,89 @@ +"""Rasteration Plan For L2 Cache Locality""" + + +class Rasterization: + panel_width_ = None + + def __init__(self) -> None: + pass + + def get_code(self) -> list[str]: + raise NotImplementedError() + + @property + def panel_width(self): + assert self.panel_width_ is not None + return self.panel_width_ + + +class NoRasterization(Rasterization): + def __init__(self) -> None: + super().__init__() + + def __repr__(self) -> str: + return "" + + def get_code(self) -> list[str]: + return [] + + +class Rasterization2DRow(Rasterization): + """ + Rasterization by Row, each Row line width is panel_width + _________ + _________| + |_________ + __________| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_code(self) -> list[str]: + raise NotImplementedError() + + +class Rasterization2DColumn(Rasterization): + """ + Rasterization by Column, each column line width is panel_width + _ + | | | | + | | | | + |_| |_| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_device_function(self) -> str: + return """ +__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) { + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? 
gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width; + const auto bz = blockIdx.z; + + dim3 blockIdx(bx, by, bz); + return blockIdx; +} + """ + + def get_code(self, panel_width: int = None) -> list[str]: + if panel_width is None: + panel_width = self.panel_width_ + return [ + self.get_device_function(), + f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n", + ] diff --git a/tilelang/original/tilelang/carver/roller/shape_inference/__init__.py b/tilelang/original/tilelang/carver/roller/shape_inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd4fea05a31c6ede9ef1b5670c7c683a22bad03 --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/shape_inference/__init__.py @@ -0,0 +1 @@ +from .tir import get_analyzer_by_tir # noqa: F401 diff --git a/tilelang/original/tilelang/carver/roller/shape_inference/common.py b/tilelang/original/tilelang/carver/roller/shape_inference/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c29ae4129831ab6a4255dc054c5b6b36f9bac50c --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/shape_inference/common.py @@ -0,0 +1,65 @@ +from collections import OrderedDict + +from tvm import arith + + +class Statement: + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): + self.output = output + self.dependent_region = dependent_region + self.var_map = var_map + self.range_map = range_map + + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + + +class InputShapeInference: + def __init__(self, deps: list[Statement]): + self.deps = deps + + def _infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int]): + shape = shape.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, bound in zip(dep.var_map.values(), shape[dep.output]): + ana.update(var, bound) + for var, bound in dep.range_map.items(): + if var.name in rstep: + bound = arith.ConstIntBound(0, min(bound.max_value, rstep[var.name] - 1)) + ana.update(var, bound) + for name, regions in dep.dependent_region.items(): + for region in regions: + bounds = [ana.const_int_bound(index) for index in region] + if name in shape: # simply merge two bounds + bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] + shape[name] = bounds + + for name, bounds in shape.items(): + shape[name] = [c.max_value - c.min_value + 1 for c in bounds] + return shape + + def infer(self, shape, rstep: dict[str, int] = None): + if rstep is None: + rstep = {} + if isinstance(shape, (list, tuple)): + shape = {"output0": [arith.ConstIntBound(0, val - 1) for val in shape]} + shape = self._infer(shape, rstep) + return shape + + def get_input_exprs(self, output_exprs): + result = output_exprs.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, expr in zip(dep.var_map.values(), result[dep.output]): + ana.bind(var, expr) + for var in dep.range_map: + ana.bind(var, 0) + for name, regions in dep.dependent_region.items(): + if name in result: + continue + region = regions[0] + input_expr = [ana.simplify(index) for index in region] + result[name] = input_expr + return result diff --git 
a/tilelang/original/tilelang/carver/roller/shape_inference/tir.py b/tilelang/original/tilelang/carver/roller/shape_inference/tir.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b11d6086720f6e24d3dbcca981b455593a6221 --- /dev/null +++ b/tilelang/original/tilelang/carver/roller/shape_inference/tir.py @@ -0,0 +1,373 @@ +from collections.abc import Mapping +from tvm.tir.schedule.schedule import BlockRV +from tvm.ir import structural_equal +from tvm import arith, tir + + +class Statement: + def __init__(self, block_analyzer, block: BlockRV): + self.block_analyzer = block_analyzer + self.block = block + # assume one tir block only has one output buffer + self.dep_name = block_analyzer.get_output_buffers(block)[0].name + self.dependent_region = _extract_dependent_region(block_analyzer, block) + + self.reverse_bound_inference = {} + + def make_reverse(self, input_name: str, input_iter: list[tir.PrimExpr]): + if len(self.block_analyzer.get_reduce_axis(self.block)) > 0: + return None + if len(self.dependent_region[input_name]) != 1: + return None + indices = self.dependent_region[input_name][0] + iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)} + iter_map_result = arith.detect_iter_map( + indices, + iter_map_range, + check_level=arith.iter_affine_map.IterMapLevel.Surjective, + simplify_trivial_iterators=False, + ) + if len(iter_map_result.errors) > 0: + return None + results = arith.iter_affine_map.inverse_affine_iter_map(iter_map_result.indices, input_iter) + output_indices = [] + for _iter in self.block_analyzer.get_spatial_axis(self.block): + if _iter.var in results: + output_indices.append(results[_iter.var]) + else: + # not Bijective mapping case + output_indices.append(tir.Var("undefined", dtype="int32") % int(_iter.dom.extent)) + return output_indices + + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + + +class TensorDepNode: + """ + For tensor dependency analysis. + """ + + def __init__(self, name): + self.name = name + self._next = [] + self._prev = [] + + def add_next(self, node): + self._next.append(node) + self.deduplicate(self._next) + + def add_prev(self, node): + self._prev.append(node) + self.deduplicate(self._prev) + + def deduplicate(self, lst): + seen = set() + lst[:] = [n for n in lst if not (n in seen or seen.add(n))] + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +class DependencyAnalysis: + def __init__(self, deps): + self.deps = deps + # issue: duplicate name when we have two same ops. + self.name2dep = self._construct_unique_name2dep(deps) + self.mapping = {} # name -> TensorDepNode + + def _construct_unique_name2dep(self, deps): + """ + This is a workaround for the issue that we have two same ops' fuse case. 
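+        For example, if two blocks both write a buffer named "T_matmul", the second
+        one is keyed as "T_matmul_1" by the renaming loop below.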
+ See https://github.com/apache/tvm/issues/16433 + """ + _names: set = set() + name2dep: Mapping = {} + for dep in deps: + output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] + base_name = output_buffer.name + if base_name not in _names: + _names.add(base_name) + else: + i = 1 + while f"{base_name}_{i}" in _names: + i += 1 + base_name = f"{base_name}_{i}" + _names.add(base_name) + name2dep[base_name] = dep + return name2dep + + def get_or_create_node(self, name): + if name not in self.mapping: + self.mapping[name] = TensorDepNode(name) + return self.mapping[name] + + def traverse_dependencies(self, compute): + if isinstance(compute, Statement): + node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name) + # Loop through input tensors + for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): + # Get the input node + input_node = self.traverse_dependencies(input_buffer) + input_node.add_next(node) + node.add_prev(input_node) + elif isinstance(compute, tir.Buffer): + node = self.get_or_create_node(compute.name) + return node + + def analyze(self): + # Starting point for traversal + for _, compute in self.name2dep.items(): + self.traverse_dependencies(compute) + + def print_dependencies(self): + for name, node in self.mapping.items(): + print(f"{name} depends on {', '.join([prev.name for prev in node._prev])}") + + def find_path_from_source(self, start_name, target_name): + """ + Finds the path (if it exists) from a starting node (source) to a target node. + Returns the path as a list of nodes. + """ + visited = set() + path = [] + if self._find_path_recursive(self.mapping[start_name], target_name, visited, path): + return path + return [] + + def _find_path_recursive(self, current_node, target_name, visited, path): + """ + Recursive helper function for find_path_from_source. 
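+        Performs a depth-first search along the `_next` edges, appending nodes to
+        `path` and popping them again when a branch does not reach the target.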
+ """ + if current_node.name == target_name: + path.append(current_node) + return True + + if current_node.name in visited: + return False + + visited.add(current_node.name) + path.append(current_node) + + for next_node in current_node._next: + if self._find_path_recursive(next_node, target_name, visited, path): + return True + + path.pop() + return False + + +class InputShapeInference: + def __init__(self, deps: list[Statement]): + self.deps = deps + self.target_mapping = {} + self.buffer_mapping = {} + self.reduce_axes = [] + for dep in self.deps: + for ax in dep.block_analyzer.get_reduce_axis(dep.block): + self.reduce_axes.append(ax) + self.dep_analysis = DependencyAnalysis(self.deps) + self.dep_analysis.analyze() + + def construct_dependency_target(self, targets: tuple[str]): + if targets in self.target_mapping: + return self.target_mapping[targets] + # should be buffer name instead of block name + name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps} + mapping = {} + input_vars = [] + for target in targets: + vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)] + input_vars.append(vars) + mapping[target] = [vars] + ana = arith.Analyzer() + + for dep in self.deps: + for name in dep.dependent_region: + if name not in mapping: + continue + dep_name = dep.dep_name + indices = mapping[name][0] + output_indices = dep.make_reverse(name, indices) + if dep_name in targets: + continue + if dep_name not in mapping: + mapping[dep_name] = [output_indices] + elif not region_exist_in_list(output_indices, mapping[dep_name]): + mapping[dep_name].append(output_indices) + + for dep in reversed(self.deps): + indices_list = mapping[dep.dep_name] + ax_vars = [iter.var for iter in dep.block_analyzer.get_spatial_axis(dep.block)] + for input_name, regions in dep.dependent_region.items(): + if input_name in targets: + continue + if input_name not in mapping: + mapping[input_name] = [] + for indices in indices_list: + for region in regions: + vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)} + region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region] + if not region_exist_in_list(region, mapping[input_name]): + mapping[input_name].append(region) + buffers = [] + for dep in self.deps: + for buffer in dep.block_analyzer.get_buffers(dep.block): + buffers.append(buffer) + + for buffer in buffers: + self.buffer_mapping[buffer.name] = buffer + + self.target_mapping[targets] = input_vars, mapping + return input_vars, mapping + + def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None): + if rstep is None: + rstep = {} + compute_targets = tuple(shape.keys()) + input_vars, mapping = self.construct_dependency_target(compute_targets) + ana = arith.Analyzer() + results = {} + intermediate_bind = {} + for vars, bounds in zip(input_vars, shape.values()): + for var, bound in zip(vars, bounds): + ana.update(var, bound, True) + for ax in self.reduce_axes: + # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. 
+ if ax.var.name in rstep: + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) + else: + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) + ana.update(ax.var, bound, True) + + for name, regions in mapping.items(): + if targets is not None and name not in targets: + continue + if compute_targets[0:1] == compute_targets: + (compute_target,) = compute_targets + path = self.dep_analysis.find_path_from_source(name, compute_target) + if len(path) > 2: + intermediate_nodes = path[1:-1] + for node in intermediate_nodes: + iters = mapping[node.name] + if len(iters) != len(regions) or len(iters) != 1: + continue + if len(*iters) != len(*regions): + break + regions = iters + intermediate_bind[name] = compute_target + + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + else: + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + + for name, bounds in results.items(): + results[name] = [c.max_value - c.min_value + 1 for c in bounds] + return results, intermediate_bind + + def get_input_exprs(self, output_exprs): + input_vars, mapping = self.construct_dependency_target(tuple(output_exprs.keys())) + ana = arith.Analyzer() + for ax in self.reduce_axes: + ana.bind(ax.var, 0) + vmap = {} + for vars, exprs in zip(input_vars, output_exprs.values()): + for var, expr in zip(vars, exprs): + if expr.dtype != var.dtype: + expr = tir.Cast(var.dtype, expr) + vmap[var] = expr + result = {} + + for name, regions in mapping.items(): + region = regions[0] + result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region] + return result + + +def region_exist_in_list(a, list) -> bool: + def expr_is_same(a, b) -> bool: + if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): + return a.value == b.value + return structural_equal(a, b) + + def region_is_same(a, b) -> bool: + return all(expr_is_same(indice_a, indice_b) for indice_a, indice_b in zip(a, b)) + + return any([region_is_same(a, x) for x in list]) + + +def walk_indice(expr): + if isinstance(expr, tir.expr.BinaryOpExpr): + a = walk_indice(expr.a) + b = walk_indice(expr.b) + if a is not None and b is not None: + return expr + else: + return None + elif isinstance(expr, (tir.Var, tir.expr.ConstExpr)): + return expr + elif isinstance(expr, tir.ProducerLoad): + return None + elif isinstance(expr, tir.Cast): + a = walk_indice(expr.value) + if a is not None: + return expr + return None + elif isinstance(expr, tir.Call): + return None + else: + raise Exception(f"Unhandled node type in walk_indice(): {expr}") + + +def _extract_dependent_region(block_analyzer, block: BlockRV) -> dict[str, list[tir.PrimExpr]]: + input_buffers = block_analyzer.get_input_buffers(block) + dependent_region = {buffer.name: [] for buffer in input_buffers} + + def fvisit(x): + if not isinstance(x, tir.BufferLoad): + return + if x.buffer.name not in dependent_region: + return + index = [] + for indice, shape_limit in zip(x.indices, x.buffer.shape): + expr = walk_indice(indice) + if expr is None: + expr = tir.Var("undefined", dtype="int8") % shape_limit + if isinstance(expr, tir.IntImm) and expr.value == 0: + """for tensor ir 
zero dim smplification case. + for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1024), T.int64(1024)): + with T.block("T_dense"): + v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2]) + T.reads(A_reindex[T.int64(0), v0, v2], B_reindex[T.int64(0), v1, v2]) + T.writes(T_dense_reindex[T.int64(0), v0, v1]) + with T.init(): + T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) + T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] + For example, the T_dense_reindex has three dims, however there're only two spatial loops. + """ + continue + index.append(expr) + if not region_exist_in_list(index, dependent_region[x.buffer.name]): + dependent_region[x.buffer.name].append(index) + + stmt = block_analyzer.sch.get(block) + tir.stmt_functor.post_order_visit(stmt, fvisit=fvisit) + return dependent_region + + +def get_analyzer_by_tir(block_analyzer, args) -> InputShapeInference: + deps = [Statement(block_analyzer, block) for block in args] + + return InputShapeInference(deps) diff --git a/tilelang/original/tilelang/carver/template/__init__.py b/tilelang/original/tilelang/carver/template/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0912e02ea9e4fc6bbc6e48fff25586f7e54bb669 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/__init__.py @@ -0,0 +1,9 @@ +"""Template for the TileLang Carver.""" + +from .base import BaseTemplate # noqa: F401 +from .matmul import MatmulTemplate # noqa: F401 +from .gemv import GEMVTemplate # noqa: F401 +from .elementwise import ElementwiseTemplate # noqa: F401 +from .general_reduce import GeneralReductionTemplate # noqa: F401 +from .flashattention import FlashAttentionTemplate # noqa: F401 +from .conv import ConvTemplate # noqa: F401 diff --git a/tilelang/original/tilelang/carver/template/base.py b/tilelang/original/tilelang/carver/template/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4a699fbc7dc8ec1a649d3b37e9f5bad24586e006 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/base.py @@ -0,0 +1,180 @@ +# Import necessary modules and classes +from abc import ABC, abstractmethod # For defining abstract base classes +from dataclasses import dataclass, field # For defining data classes +from ..arch import ( # Import architecture-related utilities and classes + TileDevice, + is_volta_arch, + is_ampere_arch, + is_cdna_arch, + auto_infer_current_arch, +) +from ..roller.hint import Hint # Import the Hint class +from ..roller.node import OutputNode # Import the OutputNode class +from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions + + +@dataclass +class BaseTemplate(ABC): + """ + Base class template for hardware-aware configurations. + This serves as an abstract base class (ABC) that defines the structure + for subclasses implementing hardware-specific optimizations. + """ + + # The architecture of the device, inferred automatically unless explicitly set + _arch: TileDevice = field(default=auto_infer_current_arch(), init=False, repr=False) + + # The function associated with this template, initially None + _func: PrimFunc = field(default=None, init=False, repr=False) + + # The outputs nodes associated with this template, initially None + _output_nodes: list[OutputNode] = field(default=None, init=False, repr=False) + + @abstractmethod + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + """ + Abstract method that must be implemented by subclasses. 
+ It should return a list of hardware-aware configurations (hints) + based on the specified architecture. + + Args: + arch (TileDevice, optional): The target architecture. Defaults to None. + topk (int, optional): Number of top configurations to return. Defaults to 10. + + Returns: + List[Hint]: A list of recommended hardware-aware configurations. + """ + pass + + def with_arch(self, arch: TileDevice) -> "BaseTemplate": + """ + Sets the architecture for this template and returns itself. + + Args: + arch (TileDevice): The architecture to set. + + Returns: + BaseTemplate: The instance with the updated architecture. + """ + self._arch = arch + return self + + def has_arch(self) -> bool: + """ + Checks whether the architecture is set. + + Returns: + bool: True if the architecture is set, False otherwise. + """ + return self._arch is not None + + def is_volta_arch(self) -> bool: + """ + Checks if the current architecture is a Volta architecture. + + Returns: + bool: True if the architecture is Volta, False otherwise. + """ + return is_volta_arch(self._arch) if self._arch is not None else False + + def is_ampere_arch(self) -> bool: + """ + Checks if the current architecture is an Ampere architecture. + + Returns: + bool: True if the architecture is Ampere, False otherwise. + """ + return is_ampere_arch(self._arch) if self._arch is not None else False + + def is_cdna_arch(self) -> bool: + """ + Checks if the current architecture is a CDNA architecture. + + Returns: + bool: True if the architecture is CDNA, False otherwise. + """ + return is_cdna_arch(self._arch) if self._arch is not None else False + + def equivalent_function(self) -> PrimFunc: + """ + Returns the function associated with this template. + + Returns: + PrimFunc: The stored function. + """ + return self._func + + def initialize_function(self) -> None: + """ + Placeholder method that should be implemented by subclasses. + This method is responsible for initializing the function. + + Raises: + NotImplementedError: If not implemented in the subclass. + """ + raise NotImplementedError("initialize_function is not implemented") + + def set_function(self, func: PrimFunc) -> "BaseTemplate": + """ + Sets the function for this template and returns itself. + + Args: + func (PrimFunc): The function to associate with this template. + + Returns: + BaseTemplate: The instance with the updated function. + """ + self._func = func + return self + + def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate": + """ + Sets the output nodes for this template and returns itself. + + Args: + output_nodes (List[OutputNode]): The output nodes to associate with this template. + + Returns: + BaseTemplate: The instance with the updated output nodes. + """ + self._output_nodes = output_nodes + return self + + def recommend_hints(self, topk: int = 10) -> list[Hint]: + """ + Provides a list of recommended hardware-aware configurations. + + Args: + topk (int, optional): Number of top configurations to return. Defaults to 10. + + Returns: + List[Hint]: A list of recommended configurations. + """ + return self.get_hardware_aware_configs(self._arch, topk) + + @property + def arch(self) -> TileDevice: + """ + Returns the current architecture. + + Returns: + TileDevice: The architecture of this template. + """ + return self._arch + + @property + def output_nodes(self) -> list[OutputNode]: + """ + Returns the output nodes associated with this template. + + Returns: + List[OutputNode]: The output nodes. 
+ """ + return self._output_nodes + + def __post_init__(self): + """ + Post-initialization method that is called after the data class is created. + Ensures that the function is initialized. + """ + self.initialize_function() diff --git a/tilelang/original/tilelang/carver/template/conv.py b/tilelang/original/tilelang/carver/template/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..c339e589488b4db57637804647c188275f0bc77b --- /dev/null +++ b/tilelang/original/tilelang/carver/template/conv.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from .base import BaseTemplate +from tvm import te, tir +from ..roller import Hint +from ..utils import get_roller_hints_from_func + + +@dataclass +class ConvTemplate(BaseTemplate): + """ + A template for convolution (Conv). + + This class defines the computation for a matrix-matrix convolution + with configurable parameters such as transposition, data types, and bias addition. + + Attributes: + N (int): The number of input samples processed simultaneously in a batch. + C (int): The number of input feature maps. + H (int): The height of the input feature maps. + W (int): The width of the input feature maps. + F (int): The number of filters (kernels) applied, determining output depth. + K (int): The spatial dimensions of each convolutional filter. + S (int): The step size by which the kernel slides across the input. + D (int): The spacing between kernel elements, controlling receptive field expansion. + P (int): The number of pixels added to input borders to control output spatial dimensions. + in_dtype (str): Data type of input matrices. + out_dtype (str): Data type of output matrix. + accum_dtype (str): Data type used for accumulation. + with_bias (bool): Whether to add a bias term. + """ + + # Operation-related configuration parameters + N: int # The number of input samples processed simultaneously in a batch. + C: int # The number of input feature maps. + H: int # The height of the input feature maps. + W: int # The width of the input feature maps. + F: int # The number of filters (kernels) applied, determining output depth. + K: int # The spatial dimensions of each convolutional filter. + S: int # The step size by which the kernel slides across the input. + D: int # The spacing between kernel elements, controlling receptive field expansion. + P: int # The number of pixels added to input borders to control output spatial dimensions. + in_dtype: str = "float16" # Data type of input matrices + out_dtype: str = "float16" # Data type of output matrix + accum_dtype: str = "float16" # Data type for accumulation + with_bias: bool = False # Whether to add a bias term + + def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]: + """ + Retrieves optimized hardware-aware configurations. + + Args: + arch (TileDevice, optional): The target hardware architecture. + topk (int, optional): Number of top configurations to consider. + + Returns: + List[Hint]: A list of optimization hints for hardware acceleration. + """ + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=True) + return roller_hints + + def initialize_function(self) -> None: + """ + Defines and initializes the convolution computation. + + This method sets up placeholders for input matrices, computes + the convolution using TVM's compute API, + and optionally applies bias and type casting. + + Raises: + AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers. 
+ """ + N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P + assert ( + isinstance(N, int) + and isinstance(C, int) + and isinstance(H, int) + and isinstance(W, int) + and isinstance(F, int) + and isinstance(K, int) + and isinstance(S, int) + and isinstance(D, int) + and isinstance(P, int) + ), "Only Support Integer Params" + assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive" + + # Load configuration parameters + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias + + # Calculate kernel dimensions and output dimensions + KH, KW = K, K + OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1 + + # Define tensor shapes + input_shape = (N, H, W, C) # NHWC format input tensor + weight_shape = (KH, KW, C, F) # HWCF format weight tensor + output_shape = (N, OH, OW, F) # NHWC format output tensor + bias_shape = (F,) # Bias vector shape + + # Create TVM placeholders for input tensors + A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input tensor + B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight tensor + Bias = te.placeholder(bias_shape, name="Bias", dtype=accum_dtype) # Bias vector + + # Define reduction axes for convolution + kh = te.reduce_axis((0, KH), name="kh") + kw = te.reduce_axis((0, KW), name="kw") + c = te.reduce_axis((0, C), name="c") + + def _compute_conv(n, h, w, f): + """ + Compute function for convolution. + + Args: + n (int): Batch index. + h (int): Output height index. + w (int): Output width index. + f (int): Output channel index. + + Returns: + Computed value for output[n, h, w, f] as a sum over reduction axes. + """ + # Calculate input positions considering stride and dilation + h_in = h * S - P + kh * D + w_in = w * S - P + kw * D + + # Check if the input position is within bounds (implicit padding with 0) + return te.sum( + te.if_then_else( + te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W), + A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype), + tir.const(0, accum_dtype), + ), + axis=[kh, kw, c], + ) + + # Compute convolution result + C = te.compute( + output_shape, + fcompute=_compute_conv, + name="C", + ) + + # Optionally apply bias addition + if with_bias: + C = te.compute( + output_shape, + lambda n, h, w, f: C[n, h, w, f] + Bias[f], + name="Bias", + ) + + # Optionally cast the output to a different type + if out_dtype != accum_dtype: + C = te.compute( + output_shape, + lambda n, h, w, f: C[n, h, w, f].astype(out_dtype), + name="D", + ) + + # Set function arguments (including bias if used) + args = [A, B, Bias, C] if self.with_bias else [A, B, C] + self.set_function(te.create_prim_func(args)) + + def params_as_dict(self): + """ + Returns the template parameters as a dictionary. + + Returns: + dict: Dictionary containing template parameter values. + """ + return { + "N": self.N, + "C": self.C, + "H": self.H, + "W": self.W, + "F": self.F, + "K": self.K, + "S": self.S, + "D": self.D, + "P": self.P, + "in_dtype": self.in_dtype, + "out_dtype": self.out_dtype, + "accum_dtype": self.accum_dtype, + "with_bias": self.with_bias, + } + + @property + def class_attributes(self): + """ + Returns the class attributes in dictionary form. + + Returns: + dict: Dictionary of class attributes. 
+ """ + return self.params_as_dict() + + def __repr__(self) -> str: + """ + Returns a string representation of the class instance. + + Returns: + str: A formatted string representation of the class. + """ + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/template/elementwise.py b/tilelang/original/tilelang/carver/template/elementwise.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd30619802913b9c4196beceff89bf9c82d5d20 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/elementwise.py @@ -0,0 +1,96 @@ +# Import necessary modules +from dataclasses import dataclass # Used for defining data classes +from .base import BaseTemplate # Importing the base class for templates +from tvm import te # Importing TVM's tensor expression module +from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations +from ..roller import Hint # Importing Hint for optimization hints +from ..utils import get_roller_hints_from_func # Function to obtain optimization hints + + +@dataclass +class ElementwiseTemplate(BaseTemplate): + """ + A template for element-wise operations using TVM. + + Attributes: + shape (List[int]): The shape of the tensor. + dtype (str): The data type of the tensor (default: "float16"). + """ + + # OP Related Config + shape: list[int] = None # Shape of the tensor + dtype: str = "float16" # Data type of the tensor + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + """ + Retrieves hardware-aware optimization configurations. + + Args: + arch (TileDevice, optional): The target hardware architecture. + topk (int, optional): Number of top configurations to consider. + + Returns: + List[Hint]: A list of optimization hints for the given architecture. + """ + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=True) + return roller_hints + + def initialize_function(self) -> None: + """ + Initializes the element-wise computation function. + + Defines a simple element-wise computation: B = A + 1, where A is an input tensor. + The computation graph is built using TVM's tensor expressions. + """ + shape, dtype = self.shape, self.dtype # Extract shape and dtype + + # Define a placeholder tensor A + A = te.placeholder(shape, name="A", dtype=dtype) + + # Define the element-wise computation (adding 1 to each element) + def _compute_elementwise(*indices): + return A[indices] + 1 + + # Define the computation for B based on A + B = te.compute( + shape, + fcompute=_compute_elementwise, # Function that defines element-wise computation + name="B", + ) + + # Store input and output tensors as function arguments + args = [A, B] + + # Create and set the computation function + self.set_function(te.create_prim_func(args)) + + def params_as_dict(self): + """ + Returns the parameters of the template as a dictionary. + + Returns: + dict: A dictionary containing shape and dtype. + """ + return {"shape": self.shape, "dtype": self.dtype} + + @property + def class_attributes(self): + """ + Returns class attributes as a dictionary. + + Returns: + dict: A dictionary representation of the class attributes. + """ + return self.params_as_dict() + + def __repr__(self) -> str: + """ + Returns a string representation of the object. + + Returns: + str: A string describing the instance with its parameters. 
+ """ + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/template/flashattention.py b/tilelang/original/tilelang/carver/template/flashattention.py new file mode 100644 index 0000000000000000000000000000000000000000..933ab9585405ee7dd790b462775219fb0ccae8b1 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/flashattention.py @@ -0,0 +1,173 @@ +from dataclasses import dataclass +from .base import BaseTemplate +from tvm import te +from ..arch import TileDevice +from ..roller import Hint +from ..roller import PrimFuncNode, OutputNode, Edge +from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags + + +@dataclass +class FlashAttentionTemplate(BaseTemplate): + _output_nodes: list[OutputNode] = None + + # Operation-related configuration parameters + batch_size: int = 1 + num_heads: int = 1 + head_dim: int = 1 + seq_length: int = 1 + seq_kv_length: int = 1 + + is_causal: bool = False + + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + """ + Retrieves optimized hardware-aware configurations. + + Args: + arch (TileDevice, optional): The target hardware architecture. + topk (int, optional): Number of top configurations to consider. + + Returns: + List[Hint]: A list of optimization hints for hardware acceleration. + """ + roller_hints = get_roller_hints_from_output_nodes(self.output_nodes, arch=arch, topk=topk) + return roller_hints + + def initialize_function(self) -> None: + """ + Defines and initializes the matrix multiplication computation. + + This method sets up placeholders for input matrices, computes + the matrix multiplication using TVM's compute API, + and optionally applies bias and type casting. + + Raises: + AssertionError: If M, N, or K are not positive integers. + """ + batch_size = self.batch_size + num_heads = self.num_heads + head_dim = self.head_dim + seq_length = self.seq_length + seq_kv_length = self.seq_kv_length + + in_dtype = self.in_dtype + out_dtype = self.out_dtype + accum_dtype = self.accum_dtype + + # Equalize the input shaps into a matmul shape + QK_B, QK_M, QK_N, QK_K = batch_size * num_heads, seq_length, seq_kv_length, head_dim + SV_B, SV_M, SV_N, SV_K = batch_size * num_heads, seq_length, head_dim, seq_kv_length + + # Define tensor shapes based on transpose flags + def create_matmul(B, M, N, K): + # Define tensor shapes based on transpose flags + input_shape = (B, M, K) + weight_shape = (B, N, K) + output_shape = (B, M, N) # Shape of output matrix C + + # Create TVM placeholders for input tensors + A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input matrix A + B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight matrix B + + # Define a reduction axis for matrix multiplication + k = te.reduce_axis((0, K), name="k") + + def _compute_matmul(b, i, j): + """ + Compute function for matrix multiplication. + + Args: + i (int): Row index. + j (int): Column index. + + Returns: + Computed value for C[i, j] as a sum over the reduction axis. 
+ """ + A_indices = [b, i, k] + B_indices = [b, j, k] + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) + + # Compute matrix multiplication result + C = te.compute( + output_shape, + fcompute=_compute_matmul, + name="C", + ) + + # Optionally cast the output to a different type + if out_dtype != accum_dtype: + C = te.compute( + output_shape, + lambda b, i, j: C[b, i, j].astype(out_dtype), + name="D", + ) + + args = [A, B, C] + return te.create_prim_func(args) + + MMA0_prim_func = create_matmul(QK_B, QK_M, QK_N, QK_K) + MMA1_prim_func = create_matmul(SV_B, SV_M, SV_N, SV_K) + + self.set_function([MMA0_prim_func, MMA1_prim_func]) + + def create_node_from_function(func, name): + tensorized_func, tags = get_tensorized_func_and_tags(func, self.arch.target) + assert tags is not None + return PrimFuncNode(tensorized_func, name=name, tags=tags) + + node_0 = create_node_from_function(MMA0_prim_func, name="MMA0") + node_1 = create_node_from_function(MMA1_prim_func, name="MMA1") + + # connect the two nodes + edge = Edge(node_0, node_1, 0, 0) + node_0._out_edges.append(edge) + node_1.set_inputs(0, edge) + + output_nodes = [OutputNode(node_1)] + self.set_output_nodes(output_nodes) + + def params_as_dict(self): + """ + Returns the template parameters as a dictionary. + + Returns: + dict: Dictionary containing template parameter values. + """ + return { + "M": self.M, + "N": self.N, + "K": self.K, + "trans_A": self.trans_A, + "trans_B": self.trans_B, + "in_dtype": self.in_dtype, + "out_dtype": self.out_dtype, + "accum_dtype": self.accum_dtype, + "with_bias": self.with_bias, + } + + @property + def class_attributes(self): + """ + Returns the class attributes in dictionary form. + + Returns: + dict: Dictionary of class attributes. + """ + return self.params_as_dict() + + def __repr__(self) -> str: + """ + Returns a string representation of the class instance. + + Returns: + str: A formatted string representation of the class. + """ + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/template/gemv.py b/tilelang/original/tilelang/carver/template/gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..e7962f6ad76205f19df8df1676705ed6e6037d25 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/gemv.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from .base import BaseTemplate +from tvm import te +from ..arch import TileDevice +from ..roller import Hint +from ..utils import get_roller_hints_from_func + + +@dataclass +class GEMVTemplate(BaseTemplate): + """ + A template for Generalized Matrix-Vector Multiplication (GEMV). + + This template defines the computation for a matrix-vector multiplication + with configurable parameters such as transposition, data types, and bias addition. 
+ """ + + # Operation-related configuration parameters + N: int = None # Number of columns in matrix B (output width) + K: int = None # Number of rows in matrix B (input width) + trans_B: bool = True # Whether to transpose matrix B + in_dtype: str = "float16" # Input data type + out_dtype: str = "float16" # Output data type + accum_dtype: str = "float16" # Accumulation data type + with_bias: bool = False # Whether to add a bias term + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + """ + Retrieves optimized hardware-aware configurations. + + Args: + arch (TileDevice, optional): The target hardware architecture. + topk (int, optional): Number of top configurations to consider. + + Returns: + List[Hint]: A list of optimization hints for hardware acceleration. + """ + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk) + return roller_hints + + def initialize_function(self) -> None: + """ + Defines and initializes the GEMV computation function. + + This method sets up placeholders for input matrices, computes + the matrix-vector multiplication using TVM's compute API, + and optionally applies bias and type casting. + """ + M: int = 1 # Fixed M value, representing a single batch dimension + N, K = self.N, self.K + + # Ensure M, N, K are valid positive integers + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" + + # Load configuration parameters + trans_B = self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias + + # Define tensor shapes + input_shape = (M, K) # Shape of input matrix A + weight_shape = (K, N) if not trans_B else (N, K) # Shape of weight matrix B + output_shape = (M, N) # Shape of output matrix C + Bias_shape = (N,) # Shape of bias vector + + # Create TVM placeholders for input tensors + A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input matrix + B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight matrix + Bias = te.placeholder(Bias_shape, name="Bias", dtype=accum_dtype) # Bias vector + + # Define a reduction axis for matrix multiplication + k = te.reduce_axis((0, K), name="k") + + def _compute_matmul(i, j): + """ + Compute function for matrix-vector multiplication. + + Args: + i (int): Row index. + j (int): Column index. + + Returns: + Computed value for C[i, j] as a sum over the reduction axis. + """ + A_indices = [i, k] + B_indices = [k, j] if not trans_B else [j, k] + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) + + # Compute matrix multiplication result + C = te.compute( + output_shape, + fcompute=_compute_matmul, + name="C", + ) + + # Optionally apply bias addition + if with_bias: + C = te.compute( + output_shape, + lambda i, j: C[i, j] + Bias[j], + name="Bias", + ) + + # Optionally cast the output to a different type + if out_dtype != accum_dtype: + C = te.compute( + output_shape, + lambda i, j: C[i, j].astype(out_dtype), + name="D", + ) + + # Set function arguments (including bias if used) + args = [A, B, Bias, C] if self.with_bias else [A, B, C] + self.set_function(te.create_prim_func(args)) + + def params_as_dict(self): + """ + Returns the template parameters as a dictionary. + + Returns: + dict: Dictionary containing template parameter values. 
+ """ + return { + "N": self.N, + "K": self.K, + "trans_B": self.trans_B, + "in_dtype": self.in_dtype, + "out_dtype": self.out_dtype, + "accum_dtype": self.accum_dtype, + "with_bias": self.with_bias, + } + + @property + def class_attributes(self): + """ + Returns the class attributes in dictionary form. + + Returns: + dict: Dictionary of class attributes. + """ + return self.params_as_dict() + + def __repr__(self) -> str: + """ + Returns a string representation of the class instance. + + Returns: + str: A formatted string representation of the class. + """ + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/template/general_reduce.py b/tilelang/original/tilelang/carver/template/general_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a55157c2586afa6c6ddcaf51f80c6c985544f1 --- /dev/null +++ b/tilelang/original/tilelang/carver/template/general_reduce.py @@ -0,0 +1,122 @@ +from __future__ import annotations +from dataclasses import dataclass +from .base import BaseTemplate +from tvm import te +from ..arch import TileDevice +from ..roller import Hint +from ..utils import get_roller_hints_from_func + + +@dataclass +class GeneralReductionTemplate(BaseTemplate): + # OP Related Config + structure: str | list[str] = None + shape: list[int] = None + dtype: str = "float16" + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False) + return roller_hints + + def initialize_function(self) -> None: + """ + Parse the structure (e.g., 'SSR'), build the TVM compute definition + with the appropriate spatial and reduce axes, and store it in self._func. + """ + assert isinstance(self.structure, str), "Structure must be a string Currently." + + if self.structure is None or self.shape is None: + raise ValueError("Must provide both `structure` and `shape`.") + if len(self.structure) != len(self.shape): + raise ValueError("`structure` length must match `shape` length.") + if not all(isinstance(s, int) and s > 0 for s in self.shape): + raise ValueError("All dimensions in `shape` must be positive integers.") + + # Separate axes into spatial vs reduce + spatial_axes = [] + reduce_axes = [] + for i, axis_type in enumerate(self.structure): + if axis_type.upper() == "S": + spatial_axes.append((i, self.shape[i])) + elif axis_type.upper() == "R": + reduce_axes.append((i, self.shape[i])) + else: + raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.") + + # Create input placeholder + A = te.placeholder(shape=self.shape, dtype=self.dtype, name="A") + + # Build a list of te.reduce_axis (for R) and the final output shape (for S). + # We'll index them in order so that the compute lambda is consistent. + # Example for SSR => 2 spatial dims (i, j), 1 reduce dim (k). + + # (1) Prepare the spatial dimensions: + # The output shape is the product of all spatial axes in the same order they appear. + # We'll construct a tuple for the final te.compute's shape. Example: (i, j). + spatial_extents = [ext for (_, ext) in spatial_axes] + + # (2) Prepare reduce axes + # e.g. (k0, (0, extent)), (k1, (0, extent)), ... 
+ reduce_axis_objs = [] + for _, ext in reduce_axes: + reduce_axis_objs.append(te.reduce_axis((0, ext))) + + # We need to build a function that uses the correct index mapping. + # Let's define a small helper that maps from the "spatial" indices to the + # correct A[] indexing, and includes the reduce axes as well. + + # The final compute's shape is precisely the number of spatial axes in the same order. + out_shape = tuple(spatial_extents) + + # We'll create a lambda of the form: + # (i, j, ...) -> te.sum(A[i, j, k, ...], axis=[k, ...]) + # We can do this dynamically by constructing indexing for each dimension in `A`. + + def compute_func(*spatial_indices): + # spatial_indices is a tuple of the same length as spatial_axes + # We must place each spatial index into the correct dimension of `A` + # or reduce_axis. Then for the reduce axes, we use the reduce_axis_objs in order. + + # We want to build a full indexing that has length = len(self.shape). + # E.g. structure='SSR', shape=[S0, S1, R2] + # i, j -> A[i, j, k] + # where i = spatial_indices[0], j = spatial_indices[1] + + full_index = [] + spatial_iter = 0 + reduce_iter = 0 + + # Walk through the structure in order + for axis_type in self.structure: + if axis_type.upper() == "S": + # use the next spatial_indices item + full_index.append(spatial_indices[spatial_iter]) + spatial_iter += 1 + else: + # axis_type is 'R', use the next reduce_axis_obj + full_index.append(reduce_axis_objs[reduce_iter]) + reduce_iter += 1 + + # Now we do the sum: + return te.sum(A[tuple(full_index)], axis=tuple(reduce_axis_objs)) + + # Construct the output tensor with te.compute + C = te.compute(out_shape, compute_func, name="C") + + # Create a PrimFunc from placeholders + output + args = [A, C] + prim_func = te.create_prim_func(args) + self.set_function(prim_func) + + def params_as_dict(self): + return {"shape": self.shape, "dtype": self.dtype} + + @property + def class_attributes(self): + return self.params_as_dict() + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/template/matmul.py b/tilelang/original/tilelang/carver/template/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..57c92beb75d870cf17c491427241b1b9926ffe3d --- /dev/null +++ b/tilelang/original/tilelang/carver/template/matmul.py @@ -0,0 +1,171 @@ +from dataclasses import dataclass +from .base import BaseTemplate +from tvm import te +from ..arch import TileDevice +from ..roller import Hint +from ..utils import get_roller_hints_from_func + + +@dataclass +class MatmulTemplate(BaseTemplate): + """ + A template for matrix multiplication (MatMul). + + This class defines the computation for a matrix-matrix multiplication + with configurable parameters such as transposition, data types, and bias addition. + + Attributes: + M (int): Number of rows in matrix A and matrix C. + N (int): Number of columns in matrix B and matrix C. + K (int): Number of columns in matrix A and rows in matrix B. + trans_A (bool): Whether to transpose matrix A before multiplication. + trans_B (bool): Whether to transpose matrix B before multiplication. + in_dtype (str): Data type of input matrices. + out_dtype (str): Data type of output matrix. + accum_dtype (str): Data type used for accumulation. + with_bias (bool): Whether to add a bias term. 
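+    Example:
+        MatmulTemplate(M=1024, N=1024, K=1024) describes C = A @ B.T with
+        float16 inputs and float16 accumulation (trans_B defaults to True).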
+ """ + + # Operation-related configuration parameters + M: int = None # Number of rows in matrix A and matrix C + N: int = None # Number of columns in matrix B and matrix C + K: int = None # Number of columns in matrix A and rows in matrix B + trans_A: bool = False # Whether to transpose matrix A + trans_B: bool = True # Whether to transpose matrix B + in_dtype: str = "float16" # Data type of input matrices + out_dtype: str = "float16" # Data type of output matrix + accum_dtype: str = "float16" # Data type for accumulation + with_bias: bool = False # Whether to add a bias term + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: + """ + Retrieves optimized hardware-aware configurations. + + Args: + arch (TileDevice, optional): The target hardware architecture. + topk (int, optional): Number of top configurations to consider. + + Returns: + List[Hint]: A list of optimization hints for hardware acceleration. + """ + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=True) + return roller_hints + + def initialize_function(self) -> None: + """ + Defines and initializes the matrix multiplication computation. + + This method sets up placeholders for input matrices, computes + the matrix multiplication using TVM's compute API, + and optionally applies bias and type casting. + + Raises: + AssertionError: If M, N, or K are not positive integers. + """ + M, N, K = self.M, self.N, self.K + + # Ensure M, N, K are valid positive integers + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" + + # Load configuration parameters + trans_A, trans_B = self.trans_A, self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias + + # Define tensor shapes based on transpose flags + input_shape = (M, K) if not trans_A else (K, M) # Shape of input matrix A + weight_shape = (K, N) if not trans_B else (N, K) # Shape of weight matrix B + output_shape = (M, N) # Shape of output matrix C + Bias_shape = (N,) # Shape of bias vector + + # Create TVM placeholders for input tensors + A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input matrix A + B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight matrix B + Bias = te.placeholder(Bias_shape, name="Bias", dtype=accum_dtype) # Bias vector + + # Define a reduction axis for matrix multiplication + k = te.reduce_axis((0, K), name="k") + + def _compute_matmul(i, j): + """ + Compute function for matrix multiplication. + + Args: + i (int): Row index. + j (int): Column index. + + Returns: + Computed value for C[i, j] as a sum over the reduction axis. 
+ """ + A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed + B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) + + # Compute matrix multiplication result + C = te.compute( + output_shape, + fcompute=_compute_matmul, + name="C", + ) + + # Optionally apply bias addition + if with_bias: + C = te.compute( + output_shape, + lambda i, j: C[i, j] + Bias[j], + name="Bias", + ) + + # Optionally cast the output to a different type + if out_dtype != accum_dtype: + C = te.compute( + output_shape, + lambda i, j: C[i, j].astype(out_dtype), + name="D", + ) + + # Set function arguments (including bias if used) + args = [A, B, Bias, C] if self.with_bias else [A, B, C] + self.set_function(te.create_prim_func(args)) + + def params_as_dict(self): + """ + Returns the template parameters as a dictionary. + + Returns: + dict: Dictionary containing template parameter values. + """ + return { + "M": self.M, + "N": self.N, + "K": self.K, + "trans_A": self.trans_A, + "trans_B": self.trans_B, + "in_dtype": self.in_dtype, + "out_dtype": self.out_dtype, + "accum_dtype": self.accum_dtype, + "with_bias": self.with_bias, + } + + @property + def class_attributes(self): + """ + Returns the class attributes in dictionary form. + + Returns: + dict: Dictionary of class attributes. + """ + return self.params_as_dict() + + def __repr__(self) -> str: + """ + Returns a string representation of the class instance. + + Returns: + str: A formatted string representation of the class. + """ + cls_name = self.__class__.__name__ + fields = self.class_attributes + field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items()) + return f"{cls_name}({field_str})" diff --git a/tilelang/original/tilelang/carver/utils.py b/tilelang/original/tilelang/carver/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..67db89e39178fb8e1f1848a2e2dd95928cb4d6be --- /dev/null +++ b/tilelang/original/tilelang/carver/utils.py @@ -0,0 +1,91 @@ +from __future__ import annotations +from tvm import tir, IRModule +from tvm.tir import PrimFunc +from .arch import TileDevice +from .roller.policy import TensorCorePolicy, DefaultPolicy +from .roller.hint import Hint +from .roller.node import OutputNode +from .matmul_analysis import get_tensorized_func_and_tags +import logging + +logger = logging.getLogger(__name__) + + +def get_rasterization_code(pannel_width: int = 8) -> str: + return f""" + const int MAX_BLOCK_N = {pannel_width}; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? 
gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + """ + + +def get_roller_hints_from_func( + func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False +) -> list[Hint] | None: + func = None + if isinstance(func_or_module, tir.PrimFunc): + func = func_or_module + elif isinstance(func_or_module, IRModule): + func = retrieve_func_from_module(func_or_module) + else: + raise ValueError("Not supported type: ", type(func_or_module)) + + assert func is not None, "The function should not be None" + + roller_hints = None + if tensorcore_only: + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) + except Exception as e_msg: + logger.debug("Get tensorized func and tags failed: ", e_msg) + tags = None + if tags and tensorized_func: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + roller_hints = policy.emit_config(topk) + else: + roller_hints = None + else: + policy = DefaultPolicy.from_prim_func(func=func, arch=arch) + tensorized_func = None + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) + except Exception as e_msg: + logger.debug("Get tensorized func and tags failed: ", e_msg) + tags = None + if tags and tensorized_func: + policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags) + roller_hints = policy.emit_config(topk) + return roller_hints + + +def get_roller_hints_from_output_nodes( + output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None +) -> list[Hint] | None: + assert isinstance(output_nodes, list), "The input should be a list of functions." + + lints = [] + try: + policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None) + lints = policy.emit_config(topk) + except Exception as e_msg: + logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy") + + if len(lints) == 0: + policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None) + lints = policy.emit_config(topk) + return lints + + +def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: + if not isinstance(ir_module, IRModule): + raise ValueError("Not supported type: ", type(ir_module)) + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." 
+ func = list(ir_module.functions.values())[0] + return func diff --git a/tilelang/original/tilelang/common/__init__.py b/tilelang/original/tilelang/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56735f0ac94b93d8feed90ae98383ac28e8c7b46 --- /dev/null +++ b/tilelang/original/tilelang/common/__init__.py @@ -0,0 +1 @@ +from .transform_kind import TransformKind # noqa: F401 diff --git a/tilelang/original/tilelang/common/transform_kind.py b/tilelang/original/tilelang/common/transform_kind.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e507137beac483d099d5c076481e10d9b203c1 --- /dev/null +++ b/tilelang/original/tilelang/common/transform_kind.py @@ -0,0 +1,21 @@ +# Copied from bitblas +from enum import IntEnum + + +class TransformKind(IntEnum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + LDMatrixTransform = 3 + + def is_non_transform(self): + return self == TransformKind.NonTransform + + def is_inter_warp_transform(self): + return self == TransformKind.InterWarpTransform + + def is_intra_warp_transform(self): + return self == TransformKind.IntraWarpTransform + + def is_ld_matrix_transform(self): + return self == TransformKind.LDMatrixTransform diff --git a/tilelang/original/tilelang/contrib/__init__.py b/tilelang/original/tilelang/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03f6545f1db12864aac636f9c8817aa64abf6ee6 --- /dev/null +++ b/tilelang/original/tilelang/contrib/__init__.py @@ -0,0 +1,2 @@ +from .nvcc import compile_cuda # noqa: F401 +from .hipcc import compile_hip # noqa: F401 diff --git a/tilelang/original/tilelang/contrib/cc.py b/tilelang/original/tilelang/contrib/cc.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc459770b06452f453a76769586e63f3b6fd57e --- /dev/null +++ b/tilelang/original/tilelang/contrib/cc.py @@ -0,0 +1,437 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Util to invoke C/C++ compilers in the system.""" + +import functools +import os +import shutil +import subprocess +import platform + +# pylint: disable=invalid-name +import sys + +from tvm.base import py_str +from tvm.contrib import tar as _tar +from tvm.contrib import utils as _utils + + +def _is_linux_like(): + return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd") + + +def _is_windows_like(): + return sys.platform == "win32" + + +def get_cc(): + """Return the path to the default C/C++ compiler. + + Returns + ------- + out: Optional[str] + The path to the default C/C++ compiler, or None if none was found. 
+ """ + + if not _is_linux_like(): + return None + + env_cxx = os.environ.get("CXX") or os.environ.get("CC") + if env_cxx: + return env_cxx + cc_names = ["g++", "gcc", "clang++", "clang", "c++", "cc"] + dirs_in_path = os.get_exec_path() + for cc in cc_names: + for d in dirs_in_path: + cc_path = os.path.join(d, cc) + if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK): + return cc_path + return None + + +@functools.cache +def get_cplus_compiler(): + """Return the path to the default C/C++ compiler. + + Returns + ------- + out: Optional[str] + The path to the default C/C++ compiler, or None if none was found. + """ + + if not _is_linux_like(): + return None + + env_cxx = os.environ.get("CXX") or os.environ.get("CC") + if env_cxx: + return env_cxx + cc_names = ["g++", "clang++", "c++"] + dirs_in_path = os.get_exec_path() + for cc in cc_names: + for d in dirs_in_path: + cc_path = os.path.join(d, cc) + if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK): + return cc_path + return None + + +def is_darwin(): + return platform.system() == "Darwin" + + +def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None): + """Create shared library. + + Parameters + ---------- + output : str + The target shared library. + + objects : List[str] + List of object files. + + options : List[str] + The list of additional options string. + + cc : Optional[str] + The compiler command. + + cwd : Optional[str] + The current working directory. + + ccache_env : Optional[Dict[str, str]] + The environment variable for ccache. Set `None` to disable ccache by default. + """ + cc = cc or get_cc() + + if _is_linux_like(): + _linux_compile(output, objects, options, cc, cwd, ccache_env, compile_shared=True) + elif _is_windows_like(): + _windows_compile(output, objects, options, cwd, ccache_env) + else: + raise ValueError("Unsupported platform") + + +def _linux_ar(output, inputs, ar): + ar = ar or "ar" + + libname = os.path.basename(output) + if not libname.startswith("lib"): + libname = "lib" + libname + temp = _utils.tempdir() + temp_output = temp.relpath(libname) + cmd = [ar, "-crs", temp_output] + + # handles the case where some input files are tar of objects + # unpack them and return the list of files inside + objects = _tar.normalize_file_list_by_unpacking_tars(temp, inputs) + + cmd += objects + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "AR error:\n" + msg += py_str(out) + msg += "\nCommand line: " + " ".join(cmd) + raise RuntimeError(msg) + + shutil.move(temp_output, output) + + +def create_staticlib(output, inputs, ar=None): + """Create static library. + + Parameters + ---------- + output : str + The target shared library. + + inputs : List[str] + List of inputs files. Each input file can be a tarball + of objects or an object file. + + ar : Optional[str] + Path to the ar command to be used + """ + + if _is_linux_like(): + return _linux_ar(output, inputs, ar) + else: + raise ValueError("Unsupported platform") + + +def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_env=None): + """Create executable binary. + + Parameters + ---------- + output : str + The target executable. + + objects : List[str] + List of object files. + + options : List[str] + The list of additional options string. + + cc : Optional[str] + The compiler command. + + cwd : Optional[str] + The urrent working directory. 
+ + ccache_env : Optional[Dict[str, str]] + The environment variable for ccache. Set `None` to disable ccache by default. + """ + cc = cc or get_cc() + + if _is_linux_like(): + _linux_compile(output, objects, options, cc, cwd, ccache_env) + elif _is_windows_like(): + _windows_compile(output, objects, options, cwd, ccache_env) + else: + raise ValueError("Unsupported platform") + + +def get_global_symbol_section_map(path, *, nm=None) -> dict[str, str]: + """Get global symbols from a library via nm -g + + Parameters + ---------- + path : str + The library path + + nm: str + The path to nm command + + Returns + ------- + symbol_section_map: Dict[str, str] + A map from defined global symbol to their sections + """ + if nm is None: + if not _is_linux_like(): + raise ValueError("Unsupported platform") + nm = "nm" + + symbol_section_map = {} + + if not os.path.isfile(path): + raise FileNotFoundError(f"{path} does not exist") + + cmd = [nm, "-gU", path] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Runtime error:\n" + msg += py_str(out) + raise RuntimeError(msg) + + for line in py_str(out).split("\n"): + data = line.strip().split() + if len(data) != 3: + continue + symbol = data[-1] + section = data[-2] + symbol_section_map[symbol] = section + return symbol_section_map + + +def get_target_by_dump_machine(compiler): + """Functor of get_target_triple that can get the target triple using compiler. + + Parameters + ---------- + compiler : Optional[str] + The compiler. + + Returns + ------- + out: Callable + A function that can get target triple according to dumpmachine option of compiler. + """ + + def get_target_triple(): + """Get target triple according to dumpmachine option of compiler.""" + if compiler: + cmd = [compiler, "-dumpmachine"] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "dumpmachine error:\n" + msg += py_str(out) + return None + return py_str(out) + return None + + return get_target_triple + + +# assign so as default output format +create_shared.output_format = "so" if sys.platform != "win32" else "dll" +create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc())) + + +def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None): + """Create a cross compiler function by specializing compile_func with options. + + This function can be used to construct compile functions that + can be passed to AutoTVM measure or export_library. + + + Parameters + ---------- + compile_func : Union[str, Callable[[str, str, Optional[str]], None]] + Function that performs the actual compilation + + options : Optional[List[str]] + List of additional optional string. + + output_format : Optional[str] + Library output format. + + get_target_triple: Optional[Callable] + Function that can target triple according to dumpmachine option of compiler. + + add_files: Optional[List[str]] + List of paths to additional object, source, library files + to pass as part of the compilation. + + Returns + ------- + fcompile : Callable[[str, str, Optional[str]], None] + A compilation function that can be passed to export_library. + + Examples + -------- + .. 
code-block:: python + + from tvm.contrib import cc, ndk + # export using arm gcc + mod = build_runtime_module() + mod.export_library(path_dso, + fcompile=cc.cross_compiler("arm-linux-gnueabihf-gcc")) + # specialize ndk compilation options. + specialized_ndk = cc.cross_compiler( + ndk.create_shared, + ["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"]) + mod.export_library(path_dso, fcompile=specialized_ndk) + """ + base_options = [] if options is None else options + kwargs = {} + add_files = [] if add_files is None else add_files + + # handle case where compile_func is the name of the cc + if isinstance(compile_func, str): + kwargs = {"cc": compile_func} + compile_func = create_shared + + def _fcompile(outputs, objects, options=None): + all_options = base_options + if options is not None: + all_options += options + compile_func(outputs, objects + add_files, options=all_options, **kwargs) + + if not output_format and hasattr(compile_func, "output_format"): + output_format = compile_func.output_format + output_format = output_format if output_format else "so" + + if not get_target_triple and hasattr(compile_func, "get_target_triple"): + get_target_triple = compile_func.get_target_triple + + _fcompile.output_format = output_format + _fcompile.get_target_triple = get_target_triple + return _fcompile + + +def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False): + cmd = [compile_cmd] + if compile_cmd != "nvcc": + if compile_shared or output.endswith(".so") or output.endswith(".dylib"): + cmd += ["-shared", "-fPIC"] + if sys.platform == "darwin": + cmd += ["-undefined", "dynamic_lookup"] + elif output.endswith(".obj"): + cmd += ["-c"] + else: + if compile_shared or output.endswith(".so") or output.endswith(".dylib"): + cmd += ["-shared"] + cmd += ["-o", output] + if isinstance(objects, str): + cmd += [objects] + else: + cmd += objects + if options: + cmd += options + env = None + if ccache_env is not None: + if shutil.which("ccache"): + cmd.insert(0, "ccache") + env = os.environ.copy() + env.update(ccache_env) + else: + raise ValueError("ccache not found") + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "Compilation error:\n" + msg += py_str(out) + msg += "\nCommand line: " + " ".join(cmd) + raise RuntimeError(msg) + + +def _windows_compile(output, objects, options, cwd=None, ccache_env=None): + cmd = ["clang"] + cmd += ["-O2"] + + if output.endswith(".so") or output.endswith(".dll"): + cmd += ["-shared"] + elif output.endswith(".obj"): + cmd += ["-c"] + + if isinstance(objects, str): + objects = [objects] + cmd += ["-o", output] + cmd += objects + if options: + cmd += options + env = None + if ccache_env is not None: + if shutil.which("ccache"): + cmd.insert(0, "ccache") + env = os.environ.copy() + env.update(ccache_env) + else: + raise ValueError("ccache not found") + + try: + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) + (out, _) = proc.communicate() + except FileNotFoundError: + raise RuntimeError( + "Can not find the LLVM clang for Windows clang.exe)." + "Make sure it's installed" + " and the installation directory is in the %PATH% environment " + "variable. 
Prebuilt binaries can be found at: https://llvm.org/" + ) from None + if proc.returncode != 0: + msg = "Compilation error:\n" + msg += " ".join(cmd) + "\n" + msg += py_str(out) + + raise RuntimeError(msg) diff --git a/tilelang/original/tilelang/contrib/cutedsl/__init__.py b/tilelang/original/tilelang/contrib/cutedsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1028badeaacab26935754d10d0e2b3aafdebb65f --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/__init__.py @@ -0,0 +1,128 @@ +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import nvvm +from cutlass.cutlass_dsl import T + +# re-export cutlass.cute.arch functions first +from cutlass.cute.arch import sync_threads # noqa: F401 +from cutlass.cute.arch import alloc_smem, get_dyn_smem # noqa: F401 +from cutlass.cute.arch import warpgroup_reg_alloc, warpgroup_reg_dealloc # noqa: F401 + +from cutlass.cute import make_tensor, make_rmem_tensor, recast_ptr # noqa: F401 +from cutlass.cute.typing import Numeric + +from cutlass.base_dsl.typing import as_numeric, Int32, Uint16, Uint32 # noqa: F401 +from cutlass._mlir.dialects import llvm, arith # noqa: F401 +from cutlass._mlir import ir as mlir_ir +from cutlass.cutlass_dsl import dsl_user_op + +# Import our custom implementations (will override if names conflict) +from .mbar import * +from .cpasync import * +from .gemm_V1 import * +from .reduce import * +from .ldsm import * +from .math import * +from .threadblock_swizzle import * + +# Forward nvvm enums +from cutlass._mlir.dialects.nvvm import ( + MemOrderKind, + MemScopeKind, + AtomicOpKind, +) + +BYTES_PER_TENSORMAP = 128 +BYTES_PER_POINTER = 8 + + +def make_filled_tensor(shape, value): + t = cute.make_rmem_tensor(shape, type(value)) + t.fill(value) + return t + + +def make_tensor_at_offset(ptr: cute.Pointer, offset, shape, div_by=1): + if div_by != 1: + offset = cute.assume(cutlass.as_numeric(offset), divby=div_by) + return cute.make_tensor(ptr + offset, shape) + + +def shuffle_elect(thread_extent): + # thread_extent is the number of threads of a warpgroup + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if thread_extent == 0: + return warp_idx == 0 + else: + return (warp_idx % (thread_extent // 32)) == 0 + + +def sync_thread_partial(barrier_id=None, thread_count=None): + bar_sync_ptx(barrier_id, thread_count) + + +# Packing functions +def pack_half2(x, y): + """ + Pack two half-precision (fp16) values into a single 32-bit value. + Corresponds to CUDA's __pack_half2 intrinsic. + + This packs two fp16 values into a single int32 by treating the fp16 bits + as raw data and concatenating them. 
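+    Conceptually the packed result is ``(y_bits << 16) | x_bits``: the PTX
+    ``mov.b32 d, {x, y}`` form places the first operand in the low 16 bits
+    and the second in the high 16 bits of the destination register.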
+ """ + + @dsl_user_op + def pack_half2_impl(x_val, y_val, *, loc=None, ip=None): + # Cast fp16 to uint16 (bitcast) + x_ir = x_val.ir_value(loc=loc, ip=ip) if hasattr(x_val, "ir_value") else x_val + y_ir = y_val.ir_value(loc=loc, ip=ip) if hasattr(y_val, "ir_value") else y_val + + # Bitcast fp16 to i16 + i16_type = mlir_ir.IntegerType.get_signless(16) + x_i16 = llvm.bitcast(i16_type, x_ir, loc=loc, ip=ip) + y_i16 = llvm.bitcast(i16_type, y_ir, loc=loc, ip=ip) + + packed_xy = llvm.inline_asm( + Int32.mlir_type, + [x_i16, y_i16], + "mov.b32 $0, {$1, $2};", + "=r,h,h", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Int32(packed_xy) + + return pack_half2_impl(x, y) + + +def AtomicAdd(ptr: cute.Pointer, value: Numeric, *, loc=None, ip=None): + if ptr.dtype == cutlass.Float32: + ret = nvvm.atomicrmw( + T.f32(), + AtomicOpKind.FADD, + ptr.llvm_ptr, + ptr.dtype(value).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU, + loc=loc, + ip=ip, + ) + elif ptr.dtype == cutlass.Int32: + ret = nvvm.atomicrmw( + T.i32(), + AtomicOpKind.ADD, + ptr.llvm_ptr, + ptr.dtype(value).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU, + loc=loc, + ip=ip, + ) + else: + raise ValueError(f"Unsupported dtype: {ptr.dtype}") + return ptr.dtype(ret) diff --git a/tilelang/original/tilelang/contrib/cutedsl/cpasync.py b/tilelang/original/tilelang/contrib/cutedsl/cpasync.py new file mode 100644 index 0000000000000000000000000000000000000000..6ddeb89337161262bf06fad92f5672fcb979f16f --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/cpasync.py @@ -0,0 +1,215 @@ +from __future__ import annotations +from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op # noqa: F401 + +from cutlass._mlir.dialects import nvvm, cute_nvgpu # noqa: F401 +from cutlass._mlir import ir + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +import cutlass.cute as cute +from cutlass.cute.typing import Int, Boolean, Int32, Int16, Uint64, Union # noqa: F401 +from cutlass.impl_utils import check_value_in + +from cutlass.cute.arch import cp_async_commit_group as cp_async_commit # noqa: F401 +from cutlass.cute.arch import cp_async_wait_group as cp_async_wait # noqa: F401 + +BYTES_PER_TENSORMAP = 128 +BYTES_PER_POINTER = 8 + + +def cp_async_gs(size, dst, dst_offset, src, src_offset): + assert size in [16, 8, 4] + # use CG (cache global) to by pass L1 when loading contiguous 128B. + mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA + if isinstance(src, cute.Tensor): + src_ptr = src.iterator + elif isinstance(src, cute.Pointer): + src_ptr = src + else: + raise ValueError(f"Invalid source type: {type(src)}") + if isinstance(dst, cute.Tensor): + dst_ptr = dst.iterator + elif isinstance(dst, cute.Pointer): + dst_ptr = dst + else: + raise ValueError(f"Invalid destination type: {type(dst)}") + cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode) + + +@cute.jit +def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond): + if cond: + cp_async_gs(size, dst, dst_offset, src, src_offset) + + +@dsl_user_op +def extract_tensormap_ptr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + extract the tensormap pointer from a TMA Copy Atom. 
+ :param tma_atom: The TMA Copy Atom + :type tma_atom: CopyAtom + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + ptr_type = _cute_ir.PtrType.get(Uint64.mlir_type, _cute_ir.AddressSpace.generic, 64) + tensormap_ptr = _cute_nvgpu_ir.get_tma_desc_addr(ptr_type, exec_value, loc=loc, ip=ip) + return tensormap_ptr + + +@dsl_user_op +def tma_load(tma_desc, mbar: cute.Pointer, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None: + """ + Load data from global memory to shared memory using TMA (Tensor Memory Access). + + :param tma_desc: TMA descriptor for the tensor + :type tma_desc: CopyAtom or tensormap_ptr or Tensor of tensormap_ptr + :param mbar: Mbarrier pointer in shared memory + :type mbar: Pointer + :param smem_ptr: Destination pointer in shared memory + :type smem_ptr: Pointer + :param crd: Coordinates tuple for the tensor access + :type crd: tuple[Int, ...] + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + + if not isinstance(crd, tuple) and isinstance(tma_desc, cute.Pointer): + # Legacy signature: tma_load(smem_ptr, gmem_ptr, mbar, size) + _smem_ptr = tma_desc + _gmem_ptr = mbar + _mbar = smem_ptr + nvvm.cp_async_bulk_shared_cluster_global( + dst_mem=_smem_ptr.llvm_ptr, + src_mem=_gmem_ptr.llvm_ptr, + mbar=_mbar.llvm_ptr, + size=Int32(crd).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + else: + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.cp_async_bulk_tensor_shared_cluster_global( + dst_mem=smem_ptr.llvm_ptr, + tma_descriptor=tma_desc_ptr.llvm_ptr, + coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd], + mbar=mbar.llvm_ptr, + im2col_offsets=[], + load_mode=nvvm.CpAsyncBulkTensorLoadMode.TILE, + group=nvvm.Tcgen05GroupKind.CTA_1, + use_intrinsic=False, # set to True would lead to compile error + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_store(tma_desc, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None: + """ + Store data from shared memory to global memory using TMA (Tensor Memory Access). + + :param tma_desc: TMA descriptor for the tensor + :type tma_desc: TMA descriptor + :param smem_ptr: Source pointer in shared memory + :type smem_ptr: Pointer + :param crd: Coordinates tuple for the tensor access + :type crd: tuple[Int, ...] 
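+
+    If ``crd`` is a plain integer rather than a tuple, ``tma_desc`` is instead
+    interpreted as a global-memory pointer and ``crd`` as a copy size, and a
+    bulk shared-to-global copy is issued (sm_90/sm_90a only).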
+ """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + if not isinstance(crd, tuple): + if arch not in ("sm_90", "sm_90a"): + raise NotImplementedError("tma_store(size) path is only implemented for sm_90/sm_90a") + gmem_ptr = tma_desc.align(smem_ptr.alignment) + _cute_nvgpu_ir.arch_copy_SM90_bulk_copy_s2g( + dsmem_data_addr=smem_ptr.value, + gmem_data_addr=gmem_ptr.value, + size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), crd), + loc=loc, + ip=ip, + ) + else: + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_descriptor=tma_desc_ptr.llvm_ptr, + src_mem=smem_ptr.llvm_ptr, + coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd], + predicate=None, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_store_arrive(*, loc=None, ip=None) -> None: + """ + Indicate arrival of warp issuing TMA_STORE. + Corresponds to PTX instruction: cp.async.bulk.commit_group; + """ + nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def tma_store_wait(count: int, *, read=None, loc=None, ip=None) -> None: + """ + Wait for TMA_STORE operations to complete. + Corresponds to PTX instruction: cp.async.bulk.wait_group.read ; + + :param count: The number of outstanding bulk async groups to wait for + :type count: Int + """ + nvvm.cp_async_bulk_wait_group(group=count, read=read, loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_shared_global( + dst: cute.Pointer, src: cute.Pointer, cp_size: Int, modifier: nvvm.LoadCacheModifierKind, *, src_size: Int = None, loc=None, ip=None +) -> None: + """ + Asynchronously copy data from global memory to shared memory. + + :param dst: Destination pointer in shared memory + :type dst: Pointer + :param src: Source pointer in global memory + :type src: Pointer + :param size: Size of the copy in bytes + :type size: Int + :param modifier: Cache modifier + :type modifier: Int + :param cp_size: Optional copy size override + :type cp_size: Int + """ + size = src_size if src_size else cp_size + nvvm.cp_async_shared_global( + dst=dst.llvm_ptr, + src=src.llvm_ptr, + size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size), + modifier=modifier, + cp_size=Int32(cp_size).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_descriptor(tma_desc, *, loc=None, ip=None) -> None: + """ + Prefetch a TMA descriptor. 
+ Corresponds to PTX instruction: prefetch.tensormap; + """ + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.prefetch_tensormap(tma_desc_ptr.llvm_ptr, loc=loc, ip=ip) diff --git a/tilelang/original/tilelang/contrib/cutedsl/gemm_V1.py b/tilelang/original/tilelang/contrib/cutedsl/gemm_V1.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6cc71e9bde33d7f0f124a2edb6990019f8f063 --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/gemm_V1.py @@ -0,0 +1,569 @@ +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils # noqa: F401 +import math +import cutlass.utils.hopper_helpers as hopper_utils +from cutlass.utils import LayoutEnum +from cutlass.cute.nvgpu.warpgroup import OperandMajorMode, OperandSource, make_smem_layout_atom + + +def make_aligned_tensor(ptr: cute.Pointer, layout: cute.Layout, align_bytes: int, swizzle=False): + ptr = ptr.align(align_bytes) + if swizzle and isinstance(layout, cute.ComposedLayout): + ptr = cute.recast_ptr(ptr=ptr, swizzle_=layout.inner, dtype=ptr.dtype) + return cute.make_tensor(ptr, layout.outer) + return cute.make_tensor(ptr, layout) + + +def gemm_ss( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with both A and B from shared memory""" + if use_wgmma: + gemm = Gemm_SM90( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum) + else: + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm(A_ptr, B_ptr, C_ptr) + + +def gemm_rs( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with A from register/fragment and B from shared memory""" + if use_wgmma: + gemm = Gemm_SM90( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_rs(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum) + else: + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_rs(A_ptr, B_ptr, C_ptr) + + +def gemm_sr( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with A from shared memory and B from register/fragment""" + # wgmma doesn't support gemm_sr, only use SM80 + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_sr(A_ptr, B_ptr, 
C_ptr) + + +def gemm_rr( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with both A and B from register/fragment""" + # Both operands in register, no copy needed + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + # For gemm_rr, directly call _body_impl with copy_A=False, copy_B=False + gemm._body_impl(A_ptr, B_ptr, C_ptr, copy_A=False, copy_B=False) + + +class Gemm_SM80: + _instances = {} # cache instances for the same arguments + + def __new__(cls, *args): + key = args + if key not in cls._instances: + cls._instances[key] = super().__new__(cls) + return cls._instances[key] + + # in Tilelang, trans_A == 0 or trans_B == 1 means K major + # in Cute, trans == 0 means K major + def __init__( + self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type + ): + if not hasattr(self, "initialized"): + self.cta_tiler = (M, N, K) + self.mma_inst_shape = (16, 8, 16) + self.trans_A = trans_A != 0 # same with Tilelang + self.trans_B = trans_B == 0 # inverse with Tilelang + A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR + B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR + self.A_layout = self._make_smem_layout_AB(A_type, A_major_mode, 128, (M, K)) + self.B_layout = self._make_smem_layout_AB(B_type, B_major_mode, 128, (N, K)) + self.ab_dtype = A_type + self.acc_dtype = C_type + self.tiled_mma = self._make_tiled_mma(warp_m, warp_n) + self.clear_accum = clear_accum + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + is_row_major = major_mode == LayoutEnum.ROW_MAJOR + major_mode_size = smem_tiler[1] if is_row_major else smem_tiler[0] + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if is_row_major + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1) if is_row_major else (1, 0)) + return layout + + def _make_tiled_mma(self, warp_m, warp_n): + atom_layout_mnk = (warp_m, warp_n, 1) + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_inst_shape) + permutation_mnk = ( + atom_layout_mnk[0] * self.mma_inst_shape[0], + atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + tiled_mma = cute.make_tiled_mma(op, atom_layout_mnk, permutation_mnk) + return tiled_mma + + @cute.jit + def __call__( + self, + sA_ptr: cute.Pointer, + sB_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + ): + """GEMM body: both A and B from shared memory""" + self._body_impl(sA_ptr, sB_ptr, rC_ptr, copy_A=True, copy_B=True) + + @cute.jit + def body_rs( + self, + rA_ptr: cute.Pointer, # A already in register + sB_ptr: cute.Pointer, # B from shared memory + rC_ptr: cute.Pointer, + ): + """GEMM body_rs: A from register, B from shared memory""" + 
self._body_impl(rA_ptr, sB_ptr, rC_ptr, copy_A=False, copy_B=True) + + @cute.jit + def body_sr( + self, + sA_ptr: cute.Pointer, # A from shared memory + rB_ptr: cute.Pointer, # B already in register + rC_ptr: cute.Pointer, + ): + """GEMM body_sr: A from shared memory, B from register""" + self._body_impl(sA_ptr, rB_ptr, rC_ptr, copy_A=True, copy_B=False) + + @cute.jit + def _body_impl( + self, + A_ptr: cute.Pointer, + B_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + copy_A: cutlass.Constexpr = True, + copy_B: cutlass.Constexpr = True, + ): + """Internal implementation with configurable copy operations""" + tidx, _, _ = cute.arch.thread_idx() + thr_mma = self.tiled_mma.get_slice(tidx) + + tCrA = None + tCrB = None + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + # Create copy operations only for operands that need copying + if cutlass.const_expr(copy_A): + sA = make_aligned_tensor(A_ptr, self.A_layout, 16) + tCsA = thr_mma.partition_A(sA) + tCrA = self.tiled_mma.make_fragment_A(tCsA) + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_A, 4), + sA.element_type, + ) + tiled_copy_s2r_A = cute.make_tiled_copy( + atom_copy_s2r_A, + layout_tv=self.tiled_mma.tv_layout_A_tiled, + tiler_mn=(self.tiled_mma.get_tile_size(0), self.tiled_mma.get_tile_size(2)), + ) + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + else: + # A already in register + tCrA = cute.make_tensor(A_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2]))) + + if cutlass.const_expr(copy_B): + sB = make_aligned_tensor(B_ptr, self.B_layout, 16) + tCsB = thr_mma.partition_B(sB) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_B, 4), + sB.element_type, + ) + tiled_copy_s2r_B = cute.make_tiled_copy( + atom_copy_s2r_B, + layout_tv=self.tiled_mma.tv_layout_B_tiled, + tiler_mn=(self.tiled_mma.get_tile_size(1), self.tiled_mma.get_tile_size(2)), + ) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + else: + # B already in register + tCrB = cute.make_tensor(B_ptr, self.tiled_mma.partition_shape_B((self.cta_tiler[1], self.cta_tiler[2]))) + + if self.clear_accum: + tCrC.fill(0) + + for k in cutlass.range(cute.size(tCrA, mode=[2])): + if cutlass.const_expr(copy_A): + cute.copy(tiled_copy_s2r_A, tCsA_copy_view[None, None, k], tCrA_copy_view[None, None, k]) + if cutlass.const_expr(copy_B): + cute.copy(tiled_copy_s2r_B, tCsB_copy_view[None, None, k], tCrB_copy_view[None, None, k]) + cute.gemm(self.tiled_mma, tCrC, tCrA[None, None, k], tCrB[None, None, k], tCrC) + + +class Gemm_SM90: + _instances = {} # cache instances for the same arguments + + def __new__(cls, *args): + key = args + if key not in cls._instances: + cls._instances[key] = super().__new__(cls) + return cls._instances[key] + + # in Tilelang, trans_A == 0 or trans_B == 1 means K major + # in Cute, trans == 0 means K major + def __init__( + self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type + ): + if not hasattr(self, "initialized"): + self.cta_tiler = (M, N, K) + self.tiler_mn = (M, N) + self.atom_layout_mnk = (warp_m // 4, warp_n, 1) + self.trans_A = trans_A != 0 # same with Tilelang + 
self.trans_B = trans_B == 0 # inverse with Tilelang + self.a_leading_mode = OperandMajorMode.MN if self.trans_A else OperandMajorMode.K + self.b_leading_mode = OperandMajorMode.MN if self.trans_B else OperandMajorMode.K + A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR + B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR + self.A_layout = self.make_smem_layout_AB(A_type, A_major_mode, (M, K)) + self.B_layout = self.make_smem_layout_AB(B_type, B_major_mode, (N, K)) + self.a_dtype = A_type + self.b_dtype = B_type + self.acc_dtype = C_type + self.tiled_mma = None + self.A_source = None + self.clear_accum = clear_accum + + @staticmethod + def make_tma_atom( + tensor, + smem_layout_staged, + smem_tile, + mcast_dim, + ): + op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + + return tma_atom + + @staticmethod + def get_tma_atom(tensor, tiler_mk, stages=1): + smem_layout_staged = Gemm_SM90.make_smem_layout_AB(tensor.element_type, LayoutEnum.from_tensor(tensor), tiler_mk, stages) + tma_atom = Gemm_SM90.make_tma_atom(tensor, smem_layout_staged, tiler_mk, 1) + return tma_atom + + @staticmethod + def make_smem_layout_AB(dtype, major_mode: LayoutEnum, tiler_mk, stages=1): + smem_shape = tiler_mk + # Determine if K is the major mode and get the major mode size + is_k_major = major_mode.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + major_mode_size = tiler_mk[1] if is_k_major else tiler_mk[0] + + # Create SMEM layout atom for A tensor based on major mode and data type + smem_layout_atom = make_smem_layout_atom( + hopper_utils.get_smem_layout_atom(major_mode, dtype, major_mode_size), + dtype, + ) + # Tile the SMEM layout atom to the A tensor shape and add staging dimension + smem_layout = cute.tile_to_shape(smem_layout_atom, cute.append(smem_shape, stages), order=(0, 1, 2) if is_k_major else (1, 0, 2)) + return smem_layout + + def _make_tiled_mma(self, is_rsMode=False): + tiled_mma = hopper_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_leading_mode, + self.b_leading_mode, + self.acc_dtype, + self.atom_layout_mnk, + (64, self.tiler_mn[1] // self.atom_layout_mnk[1]), + OperandSource.SMEM if not is_rsMode else OperandSource.RMEM, + ) + return tiled_mma + + @cute.jit + def __call__( + self, + sA_ptr: cute.Pointer, + sB_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + wg_wait: cutlass.Constexpr = 0, + clear_accum: cutlass.Constexpr = False, + ): + tidx, _, _ = cute.arch.thread_idx() + self.tiled_mma = self._make_tiled_mma() + thr_mma = self.tiled_mma.get_slice(tidx) + + sA_ptr = cute.recast_ptr(sA_ptr, self.A_layout.inner, dtype=sA_ptr.dtype) + sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype) + sA = cute.make_tensor(sA_ptr, self.A_layout.outer) + sB = cute.make_tensor(sB_ptr, self.B_layout.outer) + + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + + tCrA = self.tiled_mma.make_fragment_A(tCsA) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + cute.nvgpu.warpgroup.fence() + if cutlass.const_expr(clear_accum): + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + else: + 
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k in cutlass.range(num_k_blocks): + tCrA_1phase = tCrA[None, None, k, 0] + tCrB_1phase = tCrB[None, None, k, 0] + cute.gemm(self.tiled_mma, tCrC, tCrA_1phase, tCrB_1phase, tCrC) + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(wg_wait >= 0): + cute.nvgpu.warpgroup.wait_group(wg_wait) + + @cute.jit + def body_rs( + self, + rA_ptr: cute.Pointer, # A already in register (Fragment) + sB_ptr: cute.Pointer, # B from shared memory + rC_ptr: cute.Pointer, + wg_wait: cutlass.Constexpr = 0, + clear_accum: cutlass.Constexpr = False, + ): + """ + GEMM body_rs for SM90/Hopper: A from register, B from shared memory. + Based on cute::tl_wgmma::GemmTensorOp::body_rs from gemm_sm90.h + """ + tidx, _, _ = cute.arch.thread_idx() + self.tiled_mma = self._make_tiled_mma(is_rsMode=True) + # if self.A_source != OperandSource.RMEM or self.tiled_mma is None: + # self.tiled_mma = self._make_tiled_mma(is_rsMode = True) + # self.A_source = OperandSource.RMEM + # B from shared memory (with swizzle) + sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype) + sB = cute.make_tensor(sB_ptr, self.B_layout.outer) + + # Use the existing tiled_mma + thr_mma = self.tiled_mma.get_slice(tidx) + + # Partition B from shared memory - standard path + tCsB = thr_mma.partition_B(sB) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + + # A already in register + # For body_rs, A is NOT partitioned through thr_mma (it's already partitioned) + # We create the tensor directly with the full shape + # This matches C++: make_tensor(make_rmem_ptr(pA), partition_shape_A(...)) + tCrA = cute.make_tensor(rA_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2]))) + + # C accumulator + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + # Fence operands (prepare for wgmma) + cute.nvgpu.warpgroup.fence() + # Note: warpgroup_arrive() is called internally by wgmma + # Set accumulation mode + if cutlass.const_expr(clear_accum): + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + else: + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + # GEMM loop + num_k_blocks = cute.size(tCrB, mode=[2]) + for k_block in cutlass.range(num_k_blocks): + # Match the indexing pattern from __call__ + # If tCrB has 4 dimensions (with pipeline), use [None, None, k, 0] + # Otherwise use [None, None, k] + tCrB_k = tCrB[None, None, k_block, 0] if cute.rank(tCrB) >= 4 else tCrB[None, None, k_block] + tCrA_k = tCrA[None, None, k_block, 0] if cute.rank(tCrA) >= 4 else tCrA[None, None, k_block] + cute.gemm(self.tiled_mma, tCrC, tCrA_k, tCrB_k, tCrC) + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(wg_wait >= 0): + cute.nvgpu.warpgroup.wait_group(wg_wait) diff --git a/tilelang/original/tilelang/contrib/cutedsl/ldsm.py b/tilelang/original/tilelang/contrib/cutedsl/ldsm.py new file mode 100644 index 0000000000000000000000000000000000000000..4f36026975e3e7337a9c1a6eb9dc6f838a55f25c --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/ldsm.py @@ -0,0 +1,127 @@ +""" +LDMATRIX and STMATRIX operations for CuTeDSL backend. 
+Based on tl_templates/cuda/ldsm.h + +These functions provide wrappers around PTX ldmatrix/stmatrix instructions +for loading/storing 8x8 matrix fragments between shared memory and registers. +""" + +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass._mlir import ir # noqa: F401 +from cutlass.cute.typing import Pointer, Int32 # noqa: F401 +import cutlass.cute as cute + + +def _to_ir_value(v, loc=None, ip=None): + """Convert value to MLIR IR, handling both cutlass types and raw MLIR Values""" + if hasattr(v, "ir_value"): + return v.ir_value(loc=loc, ip=ip) + else: + # Already an MLIR Value + return v + + +def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None): + """Internal helper for ldmatrix operations""" + layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row + assert num in [2, 4] + ret_type = llvm.StructType.get_literal([T.i32()] * num) + out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num) + for i in range(num): + out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip)) + + +def _stmatrix(smem_ptr, values, transpose, loc=None, ip=None): + """Internal helper for stmatrix operations""" + layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row + ir_values = [_to_ir_value(v, loc, ip) for v in values] + nvvm.stmatrix(smem_ptr.llvm_ptr, ir_values, layout=layout, loc=loc, ip=ip) + + +# ============================================================================ +# LDMATRIX operations (load from shared memory to registers) +# ============================================================================ + + +@dsl_user_op +def ptx_ldmatrix_x1(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 1 matrix (8x8) from shared memory""" + # _ldmatrix(smem_ptr, local_ptr, 1, False, loc, ip) + out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.row, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1) + out[0] = cute.Int32(out_i32) + + +@dsl_user_op +def ptx_ldmatrix_x2(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 2 matrices (8x8 each) from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 2, False, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x4(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 4 matrices (8x8 each) from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 4, False, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 1 matrix (8x8) with transpose from shared memory""" + out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1) + out[0] = cute.Int32(out_i32) + + +@dsl_user_op +def ptx_ldmatrix_x2_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 2 matrices (8x8 each) with transpose from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 2, True, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x4_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 4 matrices (8x8 each) with transpose from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 4, True, loc, ip) + + +# ============================================================================ +# STMATRIX 
operations (store from registers to shared memory) +# ============================================================================ + + +@dsl_user_op +def ptx_stmatrix_x1(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None: + """Store 1 matrix (8x8) to shared memory""" + _stmatrix(smem_ptr, [value0], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x2(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None: + """Store 2 matrices (8x8 each) to shared memory""" + _stmatrix(smem_ptr, [value0, value1], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x4(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None: + """Store 4 matrices (8x8 each) to shared memory""" + _stmatrix(smem_ptr, [value0, value1, value2, value3], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x1_trans(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None: + """Store 1 matrix (8x8) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0], True, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x2_trans(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None: + """Store 2 matrices (8x8 each) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0, value1], True, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x4_trans(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None: + """Store 4 matrices (8x8 each) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0, value1, value2, value3], True, loc, ip) diff --git a/tilelang/original/tilelang/contrib/cutedsl/math.py b/tilelang/original/tilelang/contrib/cutedsl/math.py new file mode 100644 index 0000000000000000000000000000000000000000..3f775091bee84a089eefce999d26a7c634f161e5 --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/math.py @@ -0,0 +1,9 @@ +import cutlass.cute as cute +from cutlass.cute.typing import Union, Numeric +from cutlass.cute.tensor import TensorSSA +from cutlass._mlir.dialects import arith +from cutlass.cute.math import exp, exp2, log, log2, log10, tan, cos, sin, sqrt # noqa: F401 + + +def divf(x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]: + return cute.math._math_op(arith.divf, fastmath, x, y) diff --git a/tilelang/original/tilelang/contrib/cutedsl/mbar.py b/tilelang/original/tilelang/contrib/cutedsl/mbar.py new file mode 100644 index 0000000000000000000000000000000000000000..ca956e2f499c316db2c9c11ec12e8be58b46d7f5 --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/mbar.py @@ -0,0 +1,45 @@ +""" +Simple wrappers that delegate to cutlass.cute.arch implementations. +We use the existing implementations from cutlass rather than reinventing the wheel. 
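+
+A typical producer/consumer sketch (illustrative only; ``mbar`` is an mbarrier
+pointer in shared memory):
+
+    mbarrier_init(mbar, arrive_count)     # once per barrier, then fence_barrier_init()
+    mbarrier_expect_tx(mbar, num_bytes)   # producer announces the expected TMA bytes
+    ...                                   # issue the async copy that arrives on the barrier
+    mbarrier_wait(mbar, phase)            # consumer spins on the parity phase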
+""" + +from cutlass.cute.typing import Pointer, Int, Int32, Boolean # noqa: F401 +from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op # noqa: F401 +from cutlass._mlir.dialects import nvvm + +from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive # noqa: F401 +from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx # noqa: F401 +from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc # noqa: F401 + +import cutlass.cute.arch as arch + + +@dsl_user_op +def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000, *, loc=None, ip=None) -> None: + """Waits on a mbarrier with a specified phase.""" + nvvm.mbarrier_try_wait_parity_shared( + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + Int32(timeout_ns).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_cp_async_arrive(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: + mbar_llvm_ptr = mbar_ptr.llvm_ptr + nvvm.cp_async_mbarrier_arrive_shared( + mbar_llvm_ptr, + noinc=False, + loc=loc, + ip=ip, + ) + + +def fence_proxy_async(): + arch.fence_proxy(arch.ProxyKind.async_shared, space=arch.SharedSpace.shared_cta) + + +def fence_barrier_init(): + arch.mbarrier_init_fence() diff --git a/tilelang/original/tilelang/contrib/cutedsl/reduce.py b/tilelang/original/tilelang/contrib/cutedsl/reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..f835b149b625b0bfe8cb94b22876bfaf768889be --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/reduce.py @@ -0,0 +1,186 @@ +""" +Reduce operations for CuTeDSL backend. +Based on tl_templates/cuda/reduce.h +""" + +from __future__ import annotations + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32, Float32 +from cutlass.cutlass_dsl import dsl_user_op, T +from cutlass._mlir.dialects import nvvm +from cutlass.cute.arch.nvvm_wrappers import shuffle_sync_op + + +@dsl_user_op +def min(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def max(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +class SumOp: + """Sum reduction operator""" + + @staticmethod + def __call__(x, y): + return x + y + + +class MaxOp: + """Max reduction operator""" + + @staticmethod + def __call__(x, y): + return max(x, y) + + +class MinOp: + """Min reduction operator""" + + @staticmethod + def __call__(x, y): + # Use cutlass.min which is JIT-friendly + return min(x, y) + + +class BitAndOp: + """Bitwise AND reduction operator""" + + @staticmethod + def __call__(x, y): + return x & y + + +class BitOrOp: + """Bitwise OR reduction operator""" + + @staticmethod + def __call__(x, y): + return x | y + + +class BitXorOp: + """Bitwise XOR reduction operator""" + + @staticmethod + def __call__(x, y): + return x ^ y + + +def bar_sync(barrier_id, number_of_threads): + cute.arch.barrier(barrier_id=barrier_id, number_of_threads=number_of_threads) + + +def bar_sync_ptx(barrier_id, 
number_of_threads): + from cutlass._mlir.dialects import llvm + + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), Int32(number_of_threads).ir_value()], + "bar.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def AllReduce(reducer, threads, scale, thread_offset, all_threads=None): + """ + AllReduce operation implementing warp/block-level reduction. + Based on tl::AllReduce from reduce.h + + Args: + reducer: Reducer operator class (SumOp, MaxOp, etc.) + threads: Number of threads participating in reduction + scale: Reduction scale factor + thread_offset: Thread ID offset + all_threads: Total number of threads in block + + Returns: + A callable object with run() and run_hopper() methods + """ + + class AllReduceInstance: + def __init__(self, reducer, threads, scale, thread_offset: cutlass.Constexpr[int], all_threads: cutlass.Constexpr[int]): + self.reducer = reducer + self.threads = threads + self.scale = scale + self.thread_offset = thread_offset + self.all_threads = all_threads if all_threads is not None else threads + + def run(self, x, red_buf: cute.Pointer = None): + """ + Perform all-reduce across threads. + Based on tl::AllReduce<...>::run from reduce.h + """ + offset = self.threads // 2 + + if offset >= 32: + # Use shared memory for large thread counts + cute.arch.sync_threads() + tidx, _, _ = cute.arch.thread_idx() + cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x + cute.arch.sync_threads() + x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0]) + else: + # Use warp shuffle for small thread counts + # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode + other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly) + x = self.reducer()(x, other) + + return ( + x + if offset == self.scale + else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf) + ) + + def run_hopper(self, x, red_buf: cute.Pointer = None): + """ + Perform all-reduce on Hopper architecture using bar.sync. 
+ Based on tl::AllReduce<...>::run_hopper from reduce.h + """ + offset = self.threads // 2 + tidx, _, _ = cute.arch.thread_idx() + if offset >= 32: + # Use inlined asm for bar.sync to avoid instruction reordering + bar_sync_ptx(1, self.all_threads) + cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x + bar_sync_ptx(2, self.all_threads) + x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0]) + else: + # Use warp shuffle for small thread counts + # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode + other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly) + x = self.reducer()(x, other) + + return ( + x + if offset == self.scale + else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run_hopper(x, red_buf) + ) + + return AllReduceInstance(reducer, threads, scale, thread_offset, all_threads) diff --git a/tilelang/original/tilelang/contrib/cutedsl/threadblock_swizzle.py b/tilelang/original/tilelang/contrib/cutedsl/threadblock_swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce78eb86061e9cec465d9a6754cf7af1edfbf05 --- /dev/null +++ b/tilelang/original/tilelang/contrib/cutedsl/threadblock_swizzle.py @@ -0,0 +1,54 @@ +import cutlass.cute as cute +from cutlass.cute.typing import Constexpr +from dataclasses import dataclass + + +@dataclass(frozen=True) +class dim3: + x: int + y: int + z: int + + +def ThreadIdx() -> dim3: + return dim3(*cute.arch.thread_idx()) + + +def BlockIdx() -> dim3: + return dim3(*cute.arch.block_idx()) + + +def GridDim() -> dim3: + return dim3(*cute.arch.grid_dim()) + + +@cute.jit +def rasterization2DRow(panel_width: Constexpr[int]) -> dim3: + blockIdx = BlockIdx() + gridDim = GridDim() + block_idx = blockIdx.x + blockIdx.y * gridDim.x + grid_size = gridDim.x * gridDim.y + panel_size = panel_width * gridDim.x + panel_offset = block_idx % panel_size + panel_idx = block_idx // panel_size + total_panel = cute.ceil_div(grid_size, panel_size) + stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.x + col_idx = (gridDim.x - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride) + row_idx = panel_offset % stride + panel_idx * panel_width + return dim3(col_idx, row_idx, blockIdx.z) + + +@cute.jit +def rasterization2DColumn(panel_width: Constexpr[int]) -> dim3: + blockIdx = BlockIdx() + gridDim = GridDim() + block_idx = blockIdx.x + blockIdx.y * gridDim.x + grid_size = gridDim.x * gridDim.y + panel_size = panel_width * gridDim.y + panel_offset = block_idx % panel_size + panel_idx = block_idx // panel_size + total_panel = cute.ceil_div(grid_size, panel_size) + stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.y + row_idx = (gridDim.y - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride) + col_idx = panel_offset % stride + panel_idx * panel_width + return dim3(col_idx, row_idx, blockIdx.z) diff --git a/tilelang/original/tilelang/contrib/dlpack.py b/tilelang/original/tilelang/contrib/dlpack.py new file mode 100644 index 0000000000000000000000000000000000000000..d80f0fdbc39302ec687893fb6d92c25406a657c8 --- /dev/null +++ b/tilelang/original/tilelang/contrib/dlpack.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Wrapping functions to bridge frameworks with DLPack support to TVM""" + +from tvm import runtime + + +def convert_func(tvm_func, tensor_type, to_dlpack_func): + """Convert a tvm function into one that accepts a tensor from another + framework, provided the other framework supports DLPACK + + Parameters + ---------- + tvm_func: Function + Built tvm function operating on arrays + + tensor_type: Type + Type of the tensors of the target framework + + to_dlpack_func: Function + Function to convert the source tensors to DLPACK + """ + assert callable(tvm_func) + import torch + + float8_dtype_map = { + torch.float8_e4m3fn: "float8_e4m3", + torch.float8_e4m3fnuz: "float8_e4m3fnuz", + torch.float8_e5m2: "float8_e5m2", + torch.float8_e5m2fnuz: "float8_e5m2", + } + + def adapt_tensor(arg): + if isinstance(arg, tensor_type): + if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: + return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype]) + return runtime.from_dlpack(to_dlpack_func(arg)) + return arg + + def _wrapper(*args): + args = tuple(adapt_tensor(arg) for arg in args) + return tvm_func(*args) + + return _wrapper diff --git a/tilelang/original/tilelang/contrib/hipcc.py b/tilelang/original/tilelang/contrib/hipcc.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7f9f9479e54dcd03d0b95779ce65fef59e448f --- /dev/null +++ b/tilelang/original/tilelang/contrib/hipcc.py @@ -0,0 +1,98 @@ +# pylint: disable=invalid-name +"""Utility to invoke hipcc compiler in the system""" +# File is copied from a modified version of hipcc.py to support +# compilation of HIP code with hipcc compiler +# Source Path: +# https://github1s.com/TileLang/tvm/blob/upstream/python/tvm/contrib/hipcc.py + +from __future__ import absolute_import as _abs + +import subprocess + +import tvm_ffi + +from tvm.contrib import utils +from tvm.base import py_str +from tvm.contrib.rocm import get_rocm_arch, find_rocm_path + + +def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False): + """Compile HIP code with hipcc. + + Parameters + ---------- + code : str + The HIP code. + + target_format : str + The target format of hipcc compiler. + + arch : str + The AMD GPU architecture. + + options : str or list of str + The additional options. + + path_target : str, optional + Output file. 
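+
+    verbose : bool
+        Whether to print the hipcc output.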
+ + Return + ------ + hsaco : bytearray + The bytearray of the hsaco + """ + if arch is None: + rocm_path = find_rocm_path() + arch = get_rocm_arch(rocm_path) + + temp = utils.tempdir() + if target_format not in ["hsaco"]: + raise ValueError("target_format must be hsaco") + temp_code = temp.relpath("my_kernel.cc") + temp_target = temp.relpath(f"my_kernel.{target_format}") + + with open(temp_code, "w") as out_file: + out_file.write(code) + + file_target = path_target if path_target else temp_target + cmd = ["hipcc"] + cmd += ["-O3", "-c"] + if isinstance(arch, str): + cmd += [f"--offload-arch={arch}"] + if target_format == "hsaco": + cmd += ["--genco"] + if options: + if isinstance(options, str): + cmd += [options] + elif isinstance(options, list): + cmd += options + else: + raise ValueError("options must be str or list of str") + + cmd += ["-o", file_target] + cmd += [temp_code] + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + (out, _) = proc.communicate() + if verbose: + print(py_str(out)) + + if proc.returncode != 0: + msg = code + msg += "\nCompilation error:\n" + msg += py_str(out) + raise RuntimeError(msg) + + with open(file_target, "rb") as f: + data = bytearray(f.read()) + if not data: + raise RuntimeError("Compilation error: empty result is generated") + return data + + +@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True) +def tilelang_callback_hip_compile(code, target): + """use hipcc to generate fatbin code for better optimization""" + hsaco = compile_hip(code, target_format="hsaco") + return hsaco diff --git a/tilelang/original/tilelang/contrib/nvcc.py b/tilelang/original/tilelang/contrib/nvcc.py new file mode 100644 index 0000000000000000000000000000000000000000..36df6c875e198e2770fc3a449f0bc652fc3bb75e --- /dev/null +++ b/tilelang/original/tilelang/contrib/nvcc.py @@ -0,0 +1,592 @@ +# pylint: disable=invalid-name +# modified from apache tvm python/tvm/contrib/nvcc.py +"""Utility to invoke nvcc compiler in the system""" + +from __future__ import annotations + +import os +import subprocess +import warnings +import contextlib +from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH +import shutil +import tempfile +import tvm_ffi +from tilelang import tvm as tvm +from tvm.target import Target + +from tvm.base import py_str +from tvm.contrib import utils + + +def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False): + """Compile cuda code with NVCC from env. + + Parameters + ---------- + code : str + The cuda code. + + target_format : str + The target format of nvcc compiler. + + arch : str + The cuda architecture. + + options : str or list of str + The additional options. + + path_target : str, optional + Output file. + + Return + ------ + cubin : bytearray + The bytearray of the cubin + """ + if arch is None: + # If None, then it will use `tvm.target.Target.current().arch`. 
+ # Target arch could be a str like "sm_xx", or a list, such as + # [ + # "-gencode", "arch=compute_52,code=sm_52", + # "-gencode", "arch=compute_70,code=sm_70" + # ] + compute_version = get_target_compute_version(Target.current(allow_none=True)) + target_arch = get_target_arch(compute_version) + arch = ["-gencode", f"arch=compute_{target_arch},code=sm_{target_arch}"] + + temp = utils.tempdir() + file_name = "tvm_kernels" + if target_format not in ["cubin", "ptx", "fatbin"]: + raise ValueError("target_format must be in cubin, ptx, fatbin") + temp_code = temp.relpath(f"{file_name}.cu") + temp_target = temp.relpath(f"{file_name}.{target_format}") + + pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() + kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None) + if kernels_output_dir is not None: + if not os.path.isdir(kernels_output_dir): + os.makedirs(kernels_output_dir) + temp_code = os.path.join(kernels_output_dir, f"{file_name}.cu") + temp_target = os.path.join(kernels_output_dir, f"{file_name}.{target_format}") + + with open(temp_code, "w") as out_file: + out_file.write(code) + + file_target = path_target if path_target else temp_target + cmd = [get_nvcc_compiler()] + cmd += [f"--{target_format}", "-O3"] + # Always include line info for better profiling and mapping + cmd += ["-lineinfo"] + if isinstance(arch, list): + cmd += arch + elif isinstance(arch, str): + cmd += ["-arch", arch] + + if options: + if isinstance(options, str): + cmd += [options] + elif isinstance(options, list): + cmd += options + else: + raise ValueError("options must be str or list of str") + + cmd += ["-o", file_target] + cmd += [temp_code] + + # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler + # just in case it is not in the path. On Windows it is not in the path by default. + # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. + # Because it is hard to do runtime compiler detection, we require nvcc is configured + # correctly by default. + # if cxx_compiler_path != "": + # cmd += ["-ccbin", cxx_compiler_path] + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + (out, _) = proc.communicate() + + if verbose: + print(py_str(out)) + + if proc.returncode != 0: + msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n" + raise RuntimeError(msg) + + with open(file_target, "rb") as f: + data = bytearray(f.read()) + if not data: + raise RuntimeError("Compilation error: empty result is generated") + return data + + +def default_compile_options(compile_flags: list[str] | None = None) -> list[str]: + """ + Build a set of default NVCC compile options for TileLang generated sources. + + Includes C++ standard and common include paths (TileLang templates, CUTLASS, + CUDA include). Merges user-provided compile flags if given. + + Parameters + ---------- + compile_flags : Optional[List[str]] + Additional flags to include. Items are split on whitespace. + + Returns + ------- + List[str] + A list of flags suitable for NVCC's command line. 
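+
+    Example
+    -------
+    An illustrative sketch (the hypothetical `cuda_src` stands in for a CUDA
+    C++ source string; the resolved include paths depend on the local
+    CUDA/CUTLASS installation):
+
+        opts = default_compile_options(["--use_fast_math"])
+        cubin = compile_cuda(cuda_src, target_format="cubin", options=opts)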
+ """ + options: list[str] = ["-std=c++17"] + try: + if TILELANG_TEMPLATE_PATH: + options.append(f"-I{TILELANG_TEMPLATE_PATH}") + except Exception: + pass + try: + if CUTLASS_INCLUDE_DIR: + options.append(f"-I{CUTLASS_INCLUDE_DIR}") + except Exception: + pass + try: + if CUDA_HOME: + options.append(f"-I{os.path.join(CUDA_HOME, 'include')}") + except Exception: + pass + + # Preserve user flags exactly, including repeated tokens required by NVCC + # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries). + if compile_flags: + import shlex + + for flag in compile_flags: + # Split each string like a shell would, preserving quoted args + tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)] + options.extend(tokens) + return options + + +def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: + """ + Compile CUDA C++ source to PTX using NVCC and return as text. + + Parameters + ---------- + code : str + CUDA C++ kernel source code. + compile_flags : Optional[List[str]] + Additional flags merged with defaults. + verbose : bool + Print NVCC output when True. + + Returns + ------- + str + PTX text. + """ + opts = default_compile_options(compile_flags) + ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose) + try: + return ptx_bytes.decode("utf-8") + except Exception: + return str(ptx_bytes) + + +def _find_tool(name: str) -> str | None: + """Find a CUDA binary in PATH or under CUDA_HOME/bin.""" + path = shutil.which(name) + if path: + return path + if CUDA_HOME: + candidate = os.path.join(CUDA_HOME, "bin", name) + if os.path.exists(candidate): + return candidate + return None + + +def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: + """ + Compile CUDA C++ source to CUBIN and disassemble to SASS. + + Uses nvdisasm if available; otherwise falls back to cuobjdump. + + Parameters + ---------- + code : str + CUDA C++ kernel source code. + compile_flags : Optional[List[str]] + Additional flags merged with defaults. + verbose : bool + Print tool outputs when True. + + Returns + ------- + str + SASS text. + """ + opts = default_compile_options(compile_flags) + cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose) + + # Write to a temp .cubin file + with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp: + tmp.write(cubin_bytes) + cubin_path = tmp.name + + # Try disassembly tools (prefer nvdisasm, fallback cuobjdump) + cand_nvdisasm = _find_tool("nvdisasm") + cand_cuobjdump = _find_tool("cuobjdump") + if not cand_nvdisasm and not cand_cuobjdump: + raise RuntimeError("Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.") + last_err: str | None = None + try: + # Attempt nvdisasm first + tools_to_try = [] + if cand_nvdisasm: + tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path])) + if cand_cuobjdump: + tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path])) + + for tool_name, cmd in tools_to_try: + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + out, _ = proc.communicate() + text = py_str(out) + if verbose: + print(f"[{tool_name}] output:\n{text}") + if proc.returncode == 0 and text.strip(): + return text + last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}" + # If we reach here, all attempts failed + raise RuntimeError(f"SASS disassembly failed. 
Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}") + finally: + with contextlib.suppress(Exception): + os.remove(cubin_path) + + +def find_cuda_path(): + """Utility function to find cuda path + + Returns + ------- + path : str + Path to cuda root. + """ + if CUDA_HOME: + return CUDA_HOME + raise RuntimeError( + "Failed to automatically detect CUDA installation. Please set the CUDA_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda)." + ) + + +def get_cuda_version(cuda_path=None): + """Utility function to get cuda version + + Parameters + ---------- + cuda_path : Optional[str] + + Path to cuda root. If None is passed, will use + `find_cuda_path()` as default. + + Returns + ------- + version : float + The cuda version + + """ + if cuda_path is None: + cuda_path = find_cuda_path() + + version_file_path = os.path.join(cuda_path, "version.txt") + if not os.path.exists(version_file_path): + # Debian/Ubuntu repackaged CUDA path + version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt") + try: + with open(version_file_path) as f: + version_str = f.read().strip().split()[-1] + return tuple(int(field) for field in version_str.split(".")) + except FileNotFoundError: + pass + + cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + out = py_str(out) + if proc.returncode == 0: + release_line = [l for l in out.split("\n") if "release" in l][0] + release_fields = [s.strip() for s in release_line.split(",")] + version_str = [f[1:] for f in release_fields if f.startswith("V")][0] + return tuple(int(field) for field in version_str.split(".")) + raise RuntimeError("Cannot read cuda version file") + + +@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True) +def find_libdevice_path(arch): + """Utility function to find libdevice + + Parameters + ---------- + arch : int + The compute architecture in int + + Returns + ------- + path : str + Path to libdevice. + """ + cuda_path = find_cuda_path() + lib_path = os.path.join(cuda_path, "nvvm/libdevice") + if not os.path.exists(lib_path): + # Debian/Ubuntu repackaged CUDA path + lib_path = os.path.join(cuda_path, "lib/nvidia-cuda-toolkit/libdevice") + selected_ver = 0 + selected_path = None + cuda_ver = get_cuda_version(cuda_path) + major_minor = (cuda_ver[0], cuda_ver[1]) + if major_minor in ( + (9, 0), + (9, 1), + (10, 0), + (10, 1), + (10, 2), + (11, 0), + (11, 1), + (11, 2), + (11, 3), + ): + path = os.path.join(lib_path, "libdevice.10.bc") + else: + for fn in os.listdir(lib_path): + if not fn.startswith("libdevice"): + continue + + try: + # expected pattern: libdevice.${ARCH}.10.bc + # e.g., libdevice.compute_20.10.bc + ver = int(fn.split(".")[-3].split("_")[-1]) + if selected_ver < ver <= arch: + selected_ver = ver + selected_path = fn + except ValueError: + # it can just be `libdevice.10.bc` in CUDA 10 + selected_path = fn + + if selected_path is None: + raise RuntimeError(f"Cannot find libdevice for arch {arch}") + path = os.path.join(lib_path, selected_path) + return path + + +def callback_libdevice_path(arch): + try: + return find_libdevice_path(arch) + except RuntimeError: + warnings.warn("Cannot find libdevice path", stacklevel=2) + return "" + + +@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True) +def get_target_compute_version(target=None): + """Utility function to get compute capability of compilation target. 
+ + Looks for the target arch in three different places, first in the target input, then the + Target.current() scope, and finally the GPU device (if it exists). + + Parameters + ---------- + target : tvm.target.Target, optional + The compilation target + + Returns + ------- + compute_version : str + compute capability of a GPU (e.g. "8.6" or "9.0") + """ + # 1. input target object + # 2. Target.current() + target = target or Target.current() + if target and target.arch: + arch = target.arch.split("_")[1] + if len(arch) == 2: + major, minor = arch + # Handle old format like sm_89 + return major + "." + minor + elif len(arch) == 3: + major = int(arch[0]) + if major < 2: + major = arch[0:2] + minor = arch[2] + return major + "." + minor + else: + # This is for arch like "sm_90a" + major, minor, suffix = arch + return major + "." + minor + "." + suffix + + # 3. GPU compute version + if tvm.cuda(0).exist: + return tvm.cuda(0).compute_version + + raise ValueError("No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target.") + + +def parse_compute_version(compute_version) -> tuple[int, int]: + """Parse compute capability string to divide major and minor version + + Parameters + ---------- + compute_version : str + compute capability of a GPU (e.g. "6.0") + + Returns + ------- + major : int + major version number + minor : int + minor version number + """ + split_ver = compute_version.split(".") + try: + major = int(split_ver[0]) + minor = int(split_ver[1]) + return major, minor + except (IndexError, ValueError) as err: + # pylint: disable=raise-missing-from + raise RuntimeError("Compute version parsing error") from err + + +def get_target_arch(compute_version) -> str: + major, minor = parse_compute_version(compute_version) + target_arch = str(major * 10 + minor) + if major >= 9: + target_arch += "a" + return target_arch + + +def have_fp16(compute_version): + """Either fp16 support is provided in the compute capability or not + + Parameters + ---------- + compute_version: str + compute capability of a GPU (e.g. "6.0") + """ + major, minor = parse_compute_version(compute_version) + # fp 16 support in reference to: + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions + conditions = [False] + conditions.append(major == 5 and minor >= 3) + conditions.append(major >= 6) + return any(conditions) + + +def have_int8(compute_version): + """Either int8 support is provided in the compute capability or not + + Parameters + ---------- + compute_version : str + compute capability of a GPU (e.g. "6.1") + """ + major, _ = parse_compute_version(compute_version) + return major >= 6 + + +def have_tensorcore(compute_version=None, target=None): + """Either TensorCore support is provided in the compute capability or not + + Parameters + ---------- + compute_version : str, optional + compute capability of a GPU (e.g. "7.0"). + + target : tvm.target.Target, optional + The compilation target, will be used to determine arch if compute_version + isn't specified. + """ + if compute_version is None: + if tvm.cuda(0).exist: + compute_version = tvm.cuda(0).compute_version + else: + if target is None or "arch" not in target.attrs: + warnings.warn( + "Tensorcore will be disabled due to no CUDA architecture specified." 
+ "Try specifying it by adding '-arch=sm_xx' to your target.", + stacklevel=2, + ) + return False + compute_version = target.attrs["arch"] + # Compute version will be in the form "sm_{major}{minor}" + major, minor = compute_version.split("_")[1] + compute_version = major + "." + minor + major, _ = parse_compute_version(compute_version) + return major >= 7 + + +def have_cudagraph(): + """Either CUDA Graph support is provided""" + try: + cuda_ver = get_cuda_version() + return not cuda_ver < (10, 0) + except RuntimeError: + return False + + +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True) +def have_bf16(compute_version): + """Either bf16 support is provided in the compute capability or not + + Parameters + ---------- + compute_version : str + compute capability of a GPU (e.g. "8.0") + """ + major, _ = parse_compute_version(compute_version) + return major >= 8 + + +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True) +def have_fp8(compute_version): + """Whether fp8 support is provided in the specified compute capability or not + + Parameters + ---------- + compute_version : str + GPU capability + """ + major, minor = parse_compute_version(compute_version) + # fp8 is supported in Ada Lovelace (8.9) or later architectures. + conditions = [False] + conditions.append(major == 8 and minor >= 9) + conditions.append(major >= 9) + return any(conditions) + + +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True) +def have_tma(target): + """Whether TMA support is provided in the specified compute capability or not + + Parameters + ---------- + target : tvm.target.Target + The compilation target + """ + if target.kind.name != "cuda": + return False + compute_version = get_target_compute_version(target) + major, minor = parse_compute_version(compute_version) + # TMA is supported in Ada Lovelace (9.0) or later architectures. + conditions = [False] + conditions.append(major >= 9) + return any(conditions) + + +def is_hopper(target): + if target.kind.name != "cuda": + return False + compute_version = get_target_compute_version(target) + major, minor = parse_compute_version(compute_version) + return major == 9 and minor == 0 + + +def get_nvcc_compiler() -> str: + """Get the path to the nvcc compiler""" + return os.path.join(find_cuda_path(), "bin", "nvcc") diff --git a/tilelang/original/tilelang/contrib/nvrtc.py b/tilelang/original/tilelang/contrib/nvrtc.py new file mode 100644 index 0000000000000000000000000000000000000000..105c518198b66aa1816d67b006d66c4d11b8f3cc --- /dev/null +++ b/tilelang/original/tilelang/contrib/nvrtc.py @@ -0,0 +1,110 @@ +from __future__ import annotations +import cuda.bindings.nvrtc as nvrtc +from typing import Literal +from tvm.target import Target +from .nvcc import get_target_compute_version, parse_compute_version + + +def get_nvrtc_version() -> tuple[int, int]: + result, major, minor = nvrtc.nvrtcVersion() + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}" + return (major, minor) + + +def compile_cuda( + code: str, + target_format: Literal["ptx", "cubin"] = "ptx", + arch: int | None = None, + options: str | list[str] | None = None, + verbose: bool = False, +) -> bytearray: + """Compile cuda code with NVRTC. + + Parameters + ---------- + code : str + The cuda code. + + target_format : Literal["ptx", "cubin"] + The target format of nvrtc compiler. + + arch : Optional[int] + The cuda architecture code. 
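+        For example, arch=90 selects compute_90a (ptx) or sm_90a (cubin); the
+        "a" suffix is appended automatically when arch >= 90.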
+ + options : Optional[Union[str, List[str]]] + The additional options. + + verbose : bool + Whether to print the verbose output. + + Return + ------ + result_bytes : bytearray + The bytearray of the cubin or ptx code. + """ + if arch is None: + # If None, then it will use `tvm.target.Target.current().arch`. + # Target arch could be a str like "80", "90", "90a", etc. + major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True))) + arch = major * 10 + minor + prefix = "compute" if target_format == "ptx" else "sm" + suffix = "a" if arch >= 90 else "" + arch_option = f"--gpu-architecture={prefix}_{arch}{suffix}" + + file_name = "tvm_kernels" + if target_format not in ["cubin", "ptx"]: + raise ValueError("target_format must be cubin or ptx") + + final_options = ["-default-device"] + if get_nvrtc_version() >= (12, 8): + final_options += ["-pch"] + if arch is not None: + final_options += [arch_option] + + if options: + if isinstance(options, str): + final_options += [options] + elif isinstance(options, list): + final_options += options + else: + raise ValueError("options must be str or list of str") + + code = "#include \n" + code + code_bytes = bytes(code, "utf-8") + result, program = nvrtc.nvrtcCreateProgram(code_bytes, bytes(file_name, "utf-8"), 0, [], []) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to create program: {result}" + + options_bytes = [bytes(flag, "utf-8") for flag in final_options] + compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0] + + if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + msg = f"{code}\nCompilation error:\n" + if verbose: + result, log_size = nvrtc.nvrtcGetProgramLogSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}" + log_bytes = bytes(log_size) + result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log: {result}" + msg += f"{log_bytes.decode('utf-8')}\n" + else: + msg += "Turn on verbose to see the full compilation log." + msg += f"Options: {' '.join(final_options)}\n" + raise RuntimeError(msg) + + if target_format == "cubin": + result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get CUBIN size: {result}" + result_bytes = bytes(cubin_size) + result = nvrtc.nvrtcGetCUBIN(program, result_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get CUBIN: {result}" + else: + result, ptx_size = nvrtc.nvrtcGetPTXSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX size: {result}" + result_bytes = bytes(ptx_size) + result = nvrtc.nvrtcGetPTX(program, result_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}" + + # Destroy handler + assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}" + + return result_bytes diff --git a/tilelang/original/tilelang/contrib/rocm.py b/tilelang/original/tilelang/contrib/rocm.py new file mode 100644 index 0000000000000000000000000000000000000000..eb61328b6fdd38602df695fd806b87ba96f3eac2 --- /dev/null +++ b/tilelang/original/tilelang/contrib/rocm.py @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility for ROCm backend""" + +# ruff: noqa +import re +import subprocess +import os +from os.path import join, exists + +import tvm_ffi +from tvm.base import py_str +import tvm.runtime +import tvm.target + +from tvm.contrib import utils + + +def find_lld(required=True): + """Find ld.lld in system. + + Parameters + ---------- + required : bool + Whether it is required, + runtime error will be raised if the compiler is required. + + Returns + ------- + valid_list : list of str + List of possible paths. + + Note + ---- + This function will first search ld.lld that + matches the major llvm version that built with tvm + """ + lld_list = [] + major = tvm.target.codegen.llvm_version_major(allow_none=True) + if major is not None: + lld_list += [f"ld.lld-{major}.0"] + lld_list += [f"ld.lld-{major}"] + lld_list += ["ld.lld"] + lld_list += [f"/opt/rocm/llvm/bin/{x}" for x in lld_list] + valid_list = [utils.which(x) for x in lld_list] + valid_list = [x for x in valid_list if x] + if not valid_list and required: + raise RuntimeError("cannot find ld.lld, candidates are: " + str(lld_list)) + return valid_list + + +def rocm_link(in_file, out_file, lld=None): + """Link relocatable ELF object to shared ELF object using lld + + Parameters + ---------- + in_file : str + Input file name (relocatable ELF object file) + + out_file : str + Output file name (shared ELF object file) + + lld : str, optional + The lld linker, if not specified, + we will try to guess the matched clang version. + """ + + # if our result has undefined symbols, it will fail to load + # (hipModuleLoad/hipModuleLoadData), but with a somewhat opaque message + # so we have ld.lld check this here. + # If you get a complaint about missing symbols you might want to check the + # list of bitcode files below. 
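+    # Effectively runs: ld.lld --no-undefined -shared <in_file> -o <out_file>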
+ args = [ + lld if lld is not None else find_lld()[0], + "--no-undefined", + "-shared", + in_file, + "-o", + out_file, + ] + proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Linking error using ld.lld:\n" + msg += py_str(out) + raise RuntimeError(msg) + + +@tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True) +def callback_rocm_link(obj_bin): + """Links object file generated from LLVM to HSA Code Object + + Parameters + ---------- + obj_bin : bytearray + The object file + + Return + ------ + cobj_bin : bytearray + The HSA Code Object + """ + tmp_dir = utils.tempdir() + tmp_obj = tmp_dir.relpath("rocm_kernel.o") + tmp_cobj = tmp_dir.relpath("rocm_kernel.co") + with open(tmp_obj, "wb") as out_file: + out_file.write(bytes(obj_bin)) + rocm_link(tmp_obj, tmp_cobj) + cobj_bin = bytearray(open(tmp_cobj, "rb").read()) + return cobj_bin + + +@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True) +def callback_rocm_bitcode_path(rocdl_dir=None): + """Utility function to find ROCm device library bitcodes + + Parameters + ---------- + rocdl_dir : str + The path to rocm library directory + The default value is the standard location + """ + # seems link order matters. + + if rocdl_dir is None: + if exists("/opt/rocm/amdgcn/bitcode/"): + rocdl_dir = "/opt/rocm/amdgcn/bitcode/" # starting with rocm 3.9 + else: + rocdl_dir = "/opt/rocm/lib/" # until rocm 3.8 + + bitcode_names = [ + "oclc_daz_opt_on", + "ocml", + "irif", # this does not exist in rocm 3.9, drop eventually + "oclc_correctly_rounded_sqrt_off", + "oclc_correctly_rounded_sqrt_on", + "oclc_daz_opt_off", + "oclc_finite_only_off", + "oclc_finite_only_on", + # todo (t-vi): an alternative might be to scan for the + "oclc_isa_version_803", + "oclc_isa_version_900", # isa version files (if the linker throws out + "oclc_isa_version_906", # the unneeded ones or we filter for the arch we need) + "oclc_isa_version_1030", + "oclc_unsafe_math_off", + "oclc_unsafe_math_on", + "oclc_wavefrontsize64_on", + "oclc_abi_version_500", + ] + + bitcode_files = [] + for n in bitcode_names: + p = join(rocdl_dir, n + ".bc") # rocm >= 3.9 + if not exists(p): # rocm <= 3.8 + p = join(rocdl_dir, n + ".amdgcn.bc") + if exists(p): + bitcode_files.append(p) + elif "isa_version" not in n and n not in {"irif"}: + raise RuntimeError("could not find bitcode " + n) + + return tvm.runtime.convert(bitcode_files) + + +def parse_compute_version(compute_version): + """Parse compute capability string to divide major and minor version + + Parameters + ---------- + compute_version : str + compute capability of a GPU (e.g. "6.0") + + Returns + ------- + major : int + major version number + minor : int + minor version number + """ + split_ver = compute_version.split(".") + try: + major = int(split_ver[0]) + minor = int(split_ver[1]) + return major, minor + except (IndexError, ValueError) as err: + # pylint: disable=raise-missing-from + raise RuntimeError("Compute version parsing error: " + str(err)) + + +def have_matrixcore(compute_version=None): + """Either MatrixCore support is provided in the compute capability or not + + Parameters + ---------- + compute_version : str, optional + compute capability of a GPU (e.g. "7.0"). 
+ + Returns + ------- + have_matrixcore : bool + True if MatrixCore support is provided, False otherwise + """ + if compute_version is None: + if tvm.rocm(0).exist: + compute_version = tvm.rocm(0).compute_version + else: + raise RuntimeError("No ROCm runtime found") + major, _ = parse_compute_version(compute_version) + # matrix core first introduced in 8.0 + if major >= 8: + return True + + return False + + +@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True) +def get_rocm_arch(rocm_path="/opt/dtk"): + # @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True) + # def get_rocm_arch(rocm_path="/opt/dtk"): + """Utility function to get the AMD GPU architecture + + Parameters + ---------- + rocm_path : str + The path to rocm installation directory + + Returns + ------- + gpu_arch : str + The AMD GPU architecture + """ + gpu_arch = "gfx900" + # check if rocm is installed + if not os.path.exists(rocm_path): + print("ROCm not detected, using default gfx900") + return gpu_arch + try: + # Execute rocminfo command + rocminfo_output = subprocess.check_output([f"{rocm_path}/bin/rocminfo"]).decode("utf-8") + + # Use regex to match the "Name" field + match = re.search(r"Name:\s+(gfx\d+[a-zA-Z]*)", rocminfo_output) + if match: + gpu_arch = match.group(1) + return gpu_arch + except subprocess.CalledProcessError: + print( + f"Unable to execute rocminfo command, \ + please ensure ROCm is installed and you have an AMD GPU on your system.\ + using default {gpu_arch}." + ) + return gpu_arch + + +def find_rocm_path(): + """Utility function to find ROCm path + + Returns + ------- + path : str + Path to ROCm root. + """ + if "ROCM_PATH" in os.environ: + return os.environ["ROCM_PATH"] + cmd = ["which", "hipcc"] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + out = out.decode("utf-8").strip() + if proc.returncode == 0: + return os.path.realpath(os.path.join(out, "../..")) + rocm_path = "/opt/rocm" + if os.path.exists(os.path.join(rocm_path, "bin/hipcc")): + return rocm_path + raise RuntimeError("Cannot find ROCm path") diff --git a/tilelang/original/tilelang/engine/__init__.py b/tilelang/original/tilelang/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..476b40a35ef2a62e1931c9c3fa2f8f23e9ec7ff0 --- /dev/null +++ b/tilelang/original/tilelang/engine/__init__.py @@ -0,0 +1,3 @@ +from .lower import lower, is_device_call # noqa: F401 +from .param import KernelParam # noqa: F401 +from .callback import register_cuda_postproc, register_hip_postproc # noqa: F401 diff --git a/tilelang/original/tilelang/engine/callback.py b/tilelang/original/tilelang/engine/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..05fafe9db9a1f176faab63af0db2b9a3da596765 --- /dev/null +++ b/tilelang/original/tilelang/engine/callback.py @@ -0,0 +1,92 @@ +from __future__ import annotations +from typing import Callable +import tvm_ffi +from tvm.target import Target + + +def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = True): + """Register a post-processing function for CUDA code generation. + + Args: + func: A callable that takes generated code (str) and target (Target) as input, + and returns the processed code (str). + override: Whether to override existing registered function. Defaults to True. 
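+
+    Example (illustrative; `add_banner` is a made-up callback name):
+        def add_banner(code: str, target: Target) -> str:
+            return "// post-processed by tilelang\n" + code
+
+        register_cuda_postproc(add_banner)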
+ """ + tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override) + + +def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True): + """Register a post-processing function for HIP code generation. + + Args: + func: A callable that takes generated code (str) and target (Target) as input, + and returns the processed code (str). + override: Whether to override existing registered function. Defaults to True. + """ + tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override) + + +def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): + """Decorator for registering CUDA post-processing callback function. + + Can be used with or without parentheses: + @register_cuda_postproc_callback + def func(code, target): ... + + @register_cuda_postproc_callback() + def func(code, target): ... + + @register_cuda_postproc_callback(override=False) + def func(code, target): ... + + Args: + func: The function to be decorated or a boolean override flag + override: Whether to override existing registered function. Defaults to True. + """ + if callable(func): + register_cuda_postproc(func, override) + return func + + if func is None or isinstance(func, bool): + _override = func if isinstance(func, bool) else override + + def _register(fn: Callable[[str, Target], str]): + register_cuda_postproc(fn, _override) + return fn + + return _register + + raise TypeError("Invalid decorator usage") + + +def register_hip_postproc_callback(func: Callable | bool = None, override: bool = True): + """Decorator for registering HIP post-processing callback function. + + Can be used with or without parentheses: + @register_hip_postproc_callback + def func(code, target): ... + + @register_hip_postproc_callback() + def func(code, target): ... + + @register_hip_postproc_callback(override=False) + def func(code, target): ... + + Args: + func: The function to be decorated or a boolean override flag + override: Whether to override existing registered function. Defaults to True. 
+ """ + if callable(func): + register_hip_postproc(func, override) + return func + + if func is None or isinstance(func, bool): + _override = func if isinstance(func, bool) else override + + def _register(fn: Callable[[str, Target], str]): + register_hip_postproc(fn, _override) + return fn + + return _register + + raise TypeError("Invalid decorator usage") diff --git a/tilelang/original/tilelang/engine/lower.py b/tilelang/original/tilelang/engine/lower.py new file mode 100644 index 0000000000000000000000000000000000000000..b44e243f055b20018222c7aa0f651e855ae6ee7c --- /dev/null +++ b/tilelang/original/tilelang/engine/lower.py @@ -0,0 +1,286 @@ +"""The compiler for TL programs.""" + +from __future__ import annotations + +import os +import os.path as osp +from typing import Callable +import tilelang.transform +from tilelang import tvm as tvm +from tvm import tir +import tvm_ffi +from tvm.ir import CallingConv +from tvm.target import Target +from tilelang.contrib import hipcc, nvcc +from tilelang.transform import PassConfigKey +from tilelang.utils.deprecated import deprecated_warning +from tilelang.engine.param import KernelParam, CompiledArtifact +from tilelang.utils.target import determine_target +from tilelang.engine.phase import ( + PreLowerSemanticCheck, + LowerAndLegalize, + OptimizeForTarget, +) + + +def is_cpu_device_backend(target: Target): + return target.kind.name == "c" + + +def has_device_kernel_launch(attrs) -> bool: + """Check if the attributes indicate a device kernel launch.""" + return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) + + +def is_device_call_c_device(func: tir.PrimFunc): + attrs = func.attrs + calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT) + is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC + + # Check if it's a C target + if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked: + return True + + return has_device_kernel_launch(attrs) + + +def is_device_call(func: tir.PrimFunc): + return has_device_kernel_launch(func.attrs) + + +def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return is_device_call_c_device if is_device_c else is_device_call + + +def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return lambda func: not get_device_call(is_device_c)(func) + + +@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) +def tilelang_callback_cuda_compile(code, target, pass_config=None): + project_root = osp.join(osp.dirname(__file__), "../..") + if "TL_TEMPLATE_PATH" in os.environ: + tl_template_path = os.environ["TL_TEMPLATE_PATH"] + else: + tl_template_path = osp.abspath(osp.join(project_root, "src")) + # TODO(lei): this indeed should be renamed into + # TL_CUTLASS_INCLUDE_PATH in the future + if "TL_CUTLASS_PATH" in os.environ: + cutlass_path = os.environ["TL_CUTLASS_PATH"] + else: + cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) + target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target)) + + arch = [f"-arch=sm_{target_arch}"] + compile_format = "cubin" + + # Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib + cfg = pass_config or {} + if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH, False): + deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7") + disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)) + enable_fast_math = not disable_fast_math + else: + 
enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)) + + ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)) + + options = [ + "-std=c++17", + "-I" + tl_template_path, + "-I" + cutlass_path, + ] + # Merge extra device compiler flags from pass config, if provided + extra_flags = cfg.get(PassConfigKey.TL_DEVICE_COMPILE_FLAGS, None) + if extra_flags: + import shlex + + if isinstance(extra_flags, str): + tokens = shlex.split(extra_flags) + else: + tokens = [] + for flag in extra_flags: + if isinstance(flag, str): + tokens.extend(shlex.split(flag)) + else: + tokens.append(str(flag)) + options += tokens + + if enable_fast_math: + options.append("--use_fast_math") + if ptxas_usage_level is not None: + options.append(f"--ptxas-options=--register-usage-level={ptxas_usage_level}") + if verbose_ptxas_output: + options.append("--ptxas-options=--verbose") + + ptx = nvcc.compile_cuda( + code, + compile_format, + arch, + options=options, + verbose=False, + ) + + return ptx + + +@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True) +def tilelang_callback_hip_compile(code, target): + project_root = osp.join(osp.dirname(__file__), "../..") + tl_template_path = osp.abspath(osp.join(project_root, "src")) + + # TODO(lei): actually this indeed should be renamed into + # TL_COMPOSABLE_KERNEL_INCLUDE_PATH in the future + if "TL_COMPOSABLE_KERNEL_PATH" in os.environ: + ck_path = os.environ["TL_COMPOSABLE_KERNEL_PATH"] + else: + ck_path = osp.abspath(osp.join(project_root, "3rdparty/composable_kernel/include")) + + hsaco = hipcc.compile_hip( + code, + target_format="hsaco", + options=[ + "-std=c++17", + "-O1", + "-I" + tl_template_path, + "-I" + ck_path, + ], + verbose=False, + ) + + return hsaco + + +def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: + tensor_types = [] + for var in func.params: + if var in func.buffer_map: + tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) + else: + if var.dtype == "handle": + raise ValueError( + f"Handle parameter {var} must be mapped to a buffer.\n" + f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer." 
+ ) + tensor_types.append(KernelParam.from_var(var)) + return tensor_types + + +def canon_target_host(target: str | Target, target_host: str | Target | None): + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" + + return target_host + + +def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule: + host_mod = tir.transform.BindTarget(target_host)(host_mod) + host_mod = tir.transform.FP8StorageLegalize()(host_mod) + host_mod = tir.transform.BF16StorageLegalize()(host_mod) + host_mod = tir.transform.LowerTVMBuiltin()(host_mod) + host_mod = tir.transform.LowerCustomDatatypes()(host_mod) + host_mod = tilelang.transform.LowerIntrin()(host_mod) + host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod) + host_mod = tir.transform.CombineContextCall()(host_mod) + if target_host.kind.name == "llvm": + host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host) + elif target_host.kind.name == "c": + host_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(host_mod, target_host) + else: + raise ValueError(f"Target host {target_host.kind.name} is not supported") + return host_mod + + +def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: + device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod) + device_mod = tilelang.transform.LowerIntrin()(device_mod) + device_mod = tir.transform.Simplify()(device_mod) + + if target.kind.name == "cuda": + global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target) + elif target.kind.name == "hip": + device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) + else: + raise ValueError(f"Target {target.kind.name} is not supported") + + return device_mod + + +def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: + device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod) + device_mod = tilelang.transform.LowerIntrin()(device_mod) + device_mod = tir.transform.Simplify()(device_mod) + if target.kind.name == "cuda": + global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile" + device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target) + elif target.kind.name == "hip": + device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) + elif target.kind.name == "c": + device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) + elif target.kind.name == "llvm": + device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) + elif target.kind.name == "webgpu": + device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target) + elif target.kind.name == "metal": + device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) + else: + raise ValueError(f"Target {target.kind.name} is not supported") + + return device_mod + + +def lower( + func_or_mod: tir.PrimFunc | tvm.IRModule, + target: str | Target = "auto", + target_host: str | Target | None = None, + runtime_only=False, + enable_host_codegen=False, + enable_device_compile=False, +) -> CompiledArtifact: + """ + enable_host_codegen: whether to enable host codegen, default is False, as we have our + own host codegen implementation in jit. 
+ enable_device_compile: whether to enable device codegen, default is False, as we have our + own device codegen implementation in jit. + """ + + mod = func_or_mod + params = None + if isinstance(func_or_mod, tir.PrimFunc): + func = func_or_mod + params = extrac_params(func) if not runtime_only else None + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + + if isinstance(target, str): + target = determine_target(target) + + target_host = canon_target_host(target, target_host) + + target_host = tvm.target.Target.canon_target(target_host) + target = tvm.target.Target(target, target_host) + + _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) + _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + + # Before lowering, do semantic check + PreLowerSemanticCheck(mod) + + # Phase 1: Lower and legalize the IR + mod = LowerAndLegalize(mod, target) + + # Phase 2: Optimize the IR for the target + mod = OptimizeForTarget(mod, target) + + host_mod = tir.transform.Filter(_is_host_call)(mod) + device_mod = tir.transform.Filter(_is_device_call)(mod) + + codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target) + + if enable_host_codegen: + host_mod = host_codegen(host_mod, target_host) + host_mod.import_module(codegen_mod) + return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) + + return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) diff --git a/tilelang/original/tilelang/engine/param.py b/tilelang/original/tilelang/engine/param.py new file mode 100644 index 0000000000000000000000000000000000000000..fe023f83fa4202b7a7ab8bb4220de89318f7cfd8 --- /dev/null +++ b/tilelang/original/tilelang/engine/param.py @@ -0,0 +1,155 @@ +"""The profiler and convert to torch utils""" + +from __future__ import annotations + +from dataclasses import dataclass +import torch +from tilelang import tvm as tvm +from tvm.tir import Buffer, IntImm, Var, PrimExpr +import tilelang.language as T + + +@dataclass +class KernelParam: + """ + Represents parameters for a kernel operation, storing dtype and shape information. + Used to describe tensor or scalar parameters in TVM/PyTorch interop. + """ + + # Use tvm.DataType (buffer.dtype) directly instead of torch.dtype to support more data types + # tvm.DataType can represent a much wider range of types than PyTorch's dtype system, + # including specialized types like float8_e4m3, float8_e5m2, custom quantized types, etc. + # This avoids information loss when converting from TVM buffer types + dtype: tvm.DataType # Data type from buffer.dtype (supports all TVM types) + shape: list[int | Var] # List of dimensions, can be integers or TVM variables + + @classmethod + def from_buffer(cls, buffer: Buffer): + """ + Creates a KernelParam instance from a TVM Buffer object. 
+ + Args: + buffer: TVM Buffer object containing dtype and shape information + + Returns: + KernelParam instance with dtype directly from buffer and shape + + Raises: + ValueError: If dimension type is not supported (not IntImm or Var) + """ + # Use buffer.dtype directly (tvm.DataType) to preserve all type information + # buffer.dtype is already a tvm.DataType object, no conversion needed + dtype = buffer.dtype + shape = [] + for s in buffer.shape: + if isinstance(s, IntImm): + shape.append(s.value) + elif isinstance(s, (Var, PrimExpr)): + shape.append(s) + else: + raise ValueError(f"Unsupported dimension type: {type(s)} {s}") + return cls(dtype, shape) + + @classmethod + def from_var(cls, var: Var): + """ + Creates a KernelParam instance from a TVM Variable object. + Used for scalar parameters. + + Args: + var: TVM Variable object containing dtype information + + Returns: + KernelParam instance representing a scalar (empty shape) + """ + # Use var.dtype directly (tvm.DataType) to preserve all type information + # var.dtype is already a tvm.DataType object, no conversion needed + dtype = var.dtype + return cls(dtype, []) + + def is_scalar(self) -> bool: + """ + Checks if the parameter represents a scalar value. + + Returns: + bool: True if parameter has no dimensions (empty shape), False otherwise + """ + return len(self.shape) == 0 + + def is_unsigned(self) -> bool: + """ + Checks if the parameter represents an unsigned integer type. + + Returns: + bool: True if parameter is an unsigned integer type, False otherwise + """ + dtype_str = str(self.dtype) + if dtype_str.startswith("torch."): + dtype_str = dtype_str[6:] + return dtype_str.startswith("uint") + + def is_float8(self) -> bool: + """ + Checks if the parameter represents a float8 type. + + Returns: + bool: True if parameter is a float8 type, False otherwise + """ + dtype_str = str(self.dtype) + if dtype_str.startswith("torch."): + dtype_str = dtype_str[6:] + return dtype_str.startswith("float8") + + def is_float4(self) -> bool: + """ + Checks if the parameter represents a float4 type. + + Returns: + bool: True if parameter is a float4 type, False otherwise + """ + dtype_str = str(self.dtype) + if dtype_str.startswith("torch."): + dtype_str = dtype_str[6:] + return dtype_str.startswith("float4") + + def is_boolean(self) -> bool: + """ + Checks if the parameter represents a boolean type. + + Returns: + bool: True if parameter is a boolean type, False otherwise + """ + dtype_str = str(self.dtype) + if dtype_str.startswith("torch."): + dtype_str = dtype_str[6:] + return dtype_str.startswith("bool") + + def torch_dtype(self) -> torch.dtype: + """ + Converts the TVM DataType to PyTorch dtype. + + This method is used when creating PyTorch tensors from KernelParam, + as PyTorch's tensor creation functions require torch.dtype. + + Returns: + torch.dtype: Corresponding PyTorch dtype + + Example: + >>> param = KernelParam.from_buffer(buffer) + >>> tensor = torch.empty(shape, dtype=param.torch_dtype()) + """ + return T.dtype(self.dtype).as_torch() + + +@dataclass +class CompiledArtifact: + """ + Represents a compiled kernel artifact containing both host and device code. + Stores all necessary components for kernel execution in the TVM runtime. 
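+
+    Instances are typically produced by the engine-level ``lower`` entry point;
+    note that ``rt_mod`` is only populated when host codegen is requested
+    (``enable_host_codegen=True``), since the JIT path performs its own
+    host-side code generation.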
+ """ + + host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution + device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code + params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel + kernel_source: str # Raw source code of the generated kernel + rt_mod: tvm.runtime.Module | None = None # Runtime module for execution, may be lazily initialized diff --git a/tilelang/original/tilelang/engine/phase.py b/tilelang/original/tilelang/engine/phase.py new file mode 100644 index 0000000000000000000000000000000000000000..0e72c837e945cb82a2d2d7b35afa7594ce9c29eb --- /dev/null +++ b/tilelang/original/tilelang/engine/phase.py @@ -0,0 +1,274 @@ +from __future__ import annotations +from tvm import tir, IRModule +from tvm.target import Target +import tilelang +from tilelang.transform import PassContext +from tilelang.contrib.nvcc import have_tma, is_hopper + + +def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + # avoid circular import + from tilelang.jit.adapter.utils import is_cuda_target + + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + if (not is_cuda_target(target)) or (not have_tma(target)): + return False + disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False) + return not disable_warp_specialized + + +def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + if not have_tma(target): + return False + disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) + return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target) + + +def allow_fence_proxy(target: Target | None = None) -> bool: + return have_tma(target) + + +def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False) + return not disable_vectorize + + +def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False) + return enable_global_thread_sync + + +def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) + if allow_warp_specialized(pass_ctx=pass_ctx, target=target): + # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass + # when warp specialization is enabled, as different warp threads may access different + # buffers, but the liveness analysis is hard because we need to do pipeline. 
+ enable_aggressive_merge = False + return enable_aggressive_merge + + +def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) + + +def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) + return enabled + + +def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") + if not formats_value: + return ["txt"] + + formats_str = formats_value.strip().lower() + valid_formats = ["txt", "png", "pdf", "svg", "all"] + + if formats_str == "all": + return ["txt", "png", "pdf", "svg"] + + if "," in formats_str: + formats_list = [f.strip() for f in formats_str.split(",")] + else: + formats_list = [formats_str] + + invalid_formats = [f for f in formats_list if f not in valid_formats] + if invalid_formats: + raise ValueError( + f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. " + f"Valid formats are: {valid_formats}. " + f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')" + ) + return formats_list + + +def LayoutVisual(mod: IRModule) -> None: + """Apply layout visualization pass if enabled.""" + if should_enable_layout_visual(): + formats = get_layout_visual_formats() + tilelang.analysis.LayoutVisual(formats=formats)(mod) + + +def PreLowerSemanticCheck(mod: IRModule) -> None: + """ + Check whether the module is valid before lowering. If not, raise a user-friendly error + in Python side instead of letting the error dive into the complicated TVM/C++ stack. + Note: This is a validation-only pipeline of passes and does not modify or return the module. + """ + + # Debug + # tilelang.analysis.ASTPrinter()(mod) + # Check if there are any invalid nested loops. + tilelang.analysis.NestedLoopChecker()(mod) + # Check if there are any invalid symbolic T.Parallel + fragment access. + tilelang.analysis.FragmentLoopChecker()(mod) + + +def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: + # Bind the target device information to the module + """ + Bind target information and progressively legalize and lower frontend Tile IR into a form suitable for downstream optimization and codegen. + + This pass pipeline: + - Binds the provided target to the module. + - Legalizes frontend Tile IR into TVM-compatible constructs. + - Simplifies expressions. + - Configures reducer layouts and performs layout inference for fragments and shared memory. + - Lowers high-level tile operations and L2 persistent maps. + - Legalizes vectorized loops and inserts safety checks for memory accesses. + - Re-simplifies to remove redundancies introduced by safety checks. + - Attempts loop vectorization for dynamic-shaped loops. + + Parameters: + mod (IRModule): The input IR module containing frontend Tile IR. + target (Target): Target device information to bind into the module. + + Returns: + IRModule: The transformed module, ready for target-specific optimization passes. 
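+
+    Example (illustrative, mirroring the driver code in ``lower``)::
+
+        mod = LowerAndLegalize(mod, target)
+        mod = OptimizeForTarget(mod, target)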
+ """ + mod = tir.transform.BindTarget(target)(mod) + + if should_force_let_inline(): + # Force-let inline whenever the pass config requests it. + mod = tilelang.transform.LetInline()(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Set layouts for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Visualize the layout + LayoutVisual(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) + # Lower l2 persistent map + mod = tilelang.transform.LowerL2Persistent()(mod) + # Legalize vectorized loops to ensure they are valid + mod = tilelang.transform.LegalizeVectorizedLoop()(mod) + # Add safety checks for memory accesses + mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) + # Simplify again to clean up any duplicated conditions + # that may have been introduced by safety checks + # use an enhanced pass to simplify the dynamic symbolics + # TODO(lei): return to tir pass when kSymbolicBound simplification + # is merged into tvm. + mod = tilelang.transform.Simplify()(mod) + # Hoist any root-block annotations to PrimFunc attrs if pass is available + mod = tilelang.transform.HoistNonRestrictParams()(mod) + return mod + + +def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: + pass_ctx = tilelang.transform.get_pass_context() + # Lower the barrier.arrive into specific initialization slot + mod = tilelang.transform.LowerSharedBarrier()(mod) + # Lower the shared.tmem into specific initialization slot + mod = tilelang.transform.LowerSharedTmem()(mod) + # which may be introduced by the LegalizeSafeMemoryAccess + if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): + mod = tilelang.transform.IfStmtBinding()(mod) + mod = tilelang.transform.MultiVersionBuffer()(mod) + mod = tilelang.transform.WarpSpecialized()(mod) + mod = tilelang.transform.InjectTmaBarrier()(mod) + # if tma is not enabled, we can also do pipeline planning + # to get better performance with async copy + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) + # warp_specialized pass will pack the if stmt into the block + # so we need to lower the opaque block first + mod = tilelang.transform.LowerOpaqueBlock()(mod) + mod = tilelang.transform.MergeIfStmt()(mod) + if is_hopper(target): + mod = tilelang.transform.RewriteWgmmaSync()(mod) + mod = tilelang.transform.InjectFenceProxy()(mod) + else: + mod = tilelang.transform.IfStmtBinding()(mod) + mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.MergeIfStmt()(mod) + if allow_fence_proxy(target=target): + # in hopper device, wgmma is an async proxy + # so we need to inject a fence proxy before it + mod = tilelang.transform.InjectFenceProxy()(mod) + + mod = tilelang.transform.LowerOpaqueBlock()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tir.transform.NarrowDataType(32)(mod) + mod = tilelang.transform.FlattenBuffer()(mod) + # 
ConfigIndexBitwidth must be applied after FlattenBuffer + # as it will flatten index computing + mod = tilelang.transform.ConfigIndexBitwidth()(mod) + mod = tir.transform.Simplify()(mod) + mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod) + mod = tilelang.transform.StorageRewrite()(mod) + mod = tir.transform.UnrollLoop()(mod) + mod = tir.transform.RenormalizeSplitPattern()(mod) + mod = tir.transform.Simplify()(mod) + mod = tir.transform.RemoveNoOp()(mod) + mod = tir.transform.RewriteUnsafeSelect()(mod) + mod = tir.transform.HoistIfThenElse()(mod) + + mod = tir.transform.VerifyMemory()(mod) + mod = tir.transform.AnnotateEntryFunc()(mod) + # TODO(lei): This is a hack to make sure the + # thread level allreduce pass can be applied + # in TL. As Tl only use one thread dimension + # the var binding information will be lost + # in the lowering process with Legalization + # and Simplify pass. + # We can find a way better to create var instead + # of putting the LowerThreadAllreduce before + # the Legalization. + mod = tir.transform.InferFragment()(mod) + mod = tilelang.transform.LowerThreadAllreduce()(mod) + + mod = tilelang.transform.LowerHopperIntrin()(mod) + # Global Barrier Synchronization must be applied before + # SplitHostDevice pass, as the global barrier + if allow_global_thread_synchronization(): + mod = tilelang.transform.ThreadSync("global")(mod) + mod = tilelang.transform.AnnotateDeviceRegions()(mod) + mod = tilelang.transform.SplitHostDevice()(mod) + mod = tilelang.transform.AnnotateReadOnlyParams()(mod) + # MergeSharedMemoryAllocations must be applied after SplitHostDevice + # because the merged allocation site is at the beginning of each device function + enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) + mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) + mod = tilelang.transform.ThreadSync("shared")(mod) + mod = tilelang.transform.ThreadSync("shared.dyn")(mod) + # Inject PTX async copy must behind the thread sync pass + # as ptx async copy won't be recognized as a valid buffer load + mod = tilelang.transform.InjectPTXAsyncCopy()(mod) + if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): + mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) + mod = tilelang.transform.MakePackedAPI()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) + + # Transform threadblock to persistent threadblock + mod = tilelang.transform.PersistThreadblock()(mod) + + return mod diff --git a/tilelang/original/tilelang/env.py b/tilelang/original/tilelang/env.py new file mode 100644 index 0000000000000000000000000000000000000000..dee07f1ad8d1979ee53154aed6d5d5dd7b3669b3 --- /dev/null +++ b/tilelang/original/tilelang/env.py @@ -0,0 +1,356 @@ +from __future__ import annotations +import sys +import os +import pathlib +import logging +import shutil +import glob +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# SETUP ENVIRONMENT VARIABLES +CUTLASS_NOT_FOUND_MESSAGE = "CUTLASS is not installed or found in the expected path" +", which may lead to compilation bugs when utilize tilelang backend." +COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = "Composable Kernel is not installed or found in the expected path" +", which may lead to compilation bugs when utilize tilelang backend." 
+TL_TEMPLATE_NOT_FOUND_MESSAGE = "TileLang is not installed or found in the expected path" +", which may lead to compilation bugs when utilize tilelang backend." +TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path" + +TL_ROOT = os.path.dirname(os.path.abspath(__file__)) +# Only expose the internal lib directory to sys.path to avoid shadowing +# common top-level module names (e.g., utils, analysis) from user projects. +TL_LIBS = [os.path.join(TL_ROOT, "lib")] +TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)] + +DEV = False +THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty") +if not os.path.exists(THIRD_PARTY_ROOT): + DEV = True + tl_dev_root = os.path.dirname(TL_ROOT) + + dev_lib_root = os.path.join(tl_dev_root, "build") + # In dev builds, place artifacts under build/lib and point search path there + # to avoid adding the entire build root to sys.path. + TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")] + THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty") + logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}") + +assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}" + +for lib in TL_LIBS: + if lib not in sys.path: + sys.path.insert(0, lib) + + +def _find_cuda_home() -> str: + """Find the CUDA install path. + + Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py + """ + # Guess #1 + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home is None: + # Guess #2 + nvcc_path = shutil.which("nvcc") + if nvcc_path is not None: + # Standard CUDA pattern + if "cuda" in nvcc_path.lower(): + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + # NVIDIA HPC SDK pattern + elif "hpc_sdk" in nvcc_path.lower(): + # Navigate to the root directory of nvhpc + cuda_home = os.path.dirname(os.path.dirname(os.path.dirname(nvcc_path))) + # Generic fallback for non-standard or symlinked installs + else: + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + + else: + # Guess #3 + if sys.platform == "win32": + cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") + cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0] + else: + # Linux/macOS + if os.path.exists("/usr/local/cuda"): + cuda_home = "/usr/local/cuda" + elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"): + cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64" + + # Validate found path + if cuda_home is None or not os.path.exists(cuda_home): + cuda_home = None + + return cuda_home if cuda_home is not None else "" + + +def _find_rocm_home() -> str: + """Find the ROCM install path.""" + rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME") + if rocm_home is None: + rocmcc_path = shutil.which("hipcc") + if rocmcc_path is not None: + rocm_home = os.path.dirname(os.path.dirname(rocmcc_path)) + else: + rocm_home = "/opt/rocm" + if not os.path.exists(rocm_home): + rocm_home = None + return rocm_home if rocm_home is not None else "" + + +# Cache control +class CacheState: + """Class to manage global kernel caching state.""" + + _enabled = True + + @classmethod + def enable(cls): + """Enable kernel caching globally.""" + cls._enabled = True + + @classmethod + def disable(cls): + """Disable kernel caching globally.""" + cls._enabled = False + + @classmethod + def is_enabled(cls) -> bool: + """Return current cache state.""" + return cls._enabled + + +@dataclass +class EnvVar: + """ + Descriptor for 
managing access to a single environment variable. + + Purpose + ------- + In many projects, access to environment variables is scattered across the codebase: + * `os.environ.get(...)` calls are repeated everywhere + * Default values are hard-coded in multiple places + * Overriding env vars for tests/debugging is messy + * There's no central place to see all environment variables a package uses + + This descriptor solves those issues by: + 1. Centralizing the definition of the variable's **key** and **default value** + 2. Allowing *dynamic* reads from `os.environ` so changes take effect immediately + 3. Supporting **forced overrides** at runtime (for unit tests or debugging) + 4. Logging a warning when a forced value is used (helps detect unexpected overrides) + 5. Optionally syncing forced values back to `os.environ` if global consistency is desired + + How it works + ------------ + - This is a `dataclass` implementing the descriptor protocol (`__get__`, `__set__`) + - When used as a class attribute, `instance.attr` triggers `__get__()` + → returns either the forced override or the live value from `os.environ` + - Assigning to the attribute (`instance.attr = value`) triggers `__set__()` + → stores `_forced_value` for future reads + - You may uncomment the `os.environ[...] = value` line in `__set__` if you want + the override to persist globally in the process + + Example + ------- + ```python + class Environment: + TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "0") + + env = Environment() + print(cfg.TILELANG_PRINT_ON_COMPILATION) # Reads from os.environ (with default fallback) + cfg.TILELANG_PRINT_ON_COMPILATION = "1" # Forces value to "1" until changed/reset + ``` + + Benefits + -------- + * Centralizes all env-var keys and defaults in one place + * Live, up-to-date reads (no stale values after `import`) + * Testing convenience (override without touching the real env) + * Improves IDE discoverability and type hints + * Avoids hardcoding `os.environ.get(...)` in multiple places + """ + + key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION") + default: str # Default value if the environment variable is not set + _forced_value: str | None = None # Temporary runtime override (mainly for tests/debugging) + + def get(self): + if self._forced_value is not None: + return self._forced_value + return os.environ.get(self.key, self.default) + + def __get__(self, instance, owner): + """ + Called when the attribute is accessed. + 1. If a forced value is set, return it and log a warning + 2. Otherwise, look up the value in os.environ; return the default if missing + """ + return self.get() + + def __set__(self, instance, value): + """ + Called when the attribute is assigned to. + Stores the value as a runtime override (forced value). + Optionally, you can also sync this into os.environ for global effect. + """ + self._forced_value = value + # Uncomment the following line if you want the override to persist globally: + # os.environ[self.key] = value + + +# Utility function for environment variables with defaults +# Assuming EnvVar and CacheState are defined elsewhere +class Environment: + """ + Environment configuration for TileLang. + Handles CUDA/ROCm detection, integration paths, template/cache locations, + auto-tuning configs, and build options. 
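+
+    Example (illustrative)::
+
+        from tilelang.env import env
+
+        cache_dir = env.TILELANG_CACHE_DIR         # live read from os.environ, with default fallback
+        env.TILELANG_PRINT_ON_COMPILATION = "0"    # runtime override via the EnvVar descriptor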
+ """ + + # CUDA/ROCm home directories + CUDA_HOME = _find_cuda_home() + ROCM_HOME = _find_rocm_home() + + # Path to the TileLang package root + TILELANG_PACKAGE_PATH = pathlib.Path(__file__).resolve().parent + + # External library include paths + CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None) + COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None) + + # TVM integration + TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None) + TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None) + + # TileLang resources + TILELANG_TEMPLATE_PATH = EnvVar("TL_TEMPLATE_PATH", None) + TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", os.path.expanduser("~/.tilelang/cache")) + TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp")) + + # Kernel Build options + TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile + TILELANG_DISABLE_CACHE = EnvVar( + "TILELANG_DISABLE_CACHE", "0" + ) # disable kernel cache, usually for unit testing / debugging, high priority + TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # DEPRECATED! clear cache automatically if set + + # Kernel selection options + # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 + TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "1") + + # Auto-tuning settings + TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") + TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used + TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto + TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit + + # TVM integration + SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0") + TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None) + + def _initialize_torch_cuda_arch_flags(self) -> None: + """ + Detect target CUDA architecture and set TORCH_CUDA_ARCH_LIST + to ensure PyTorch extensions are built for the proper GPU arch. + """ + from tilelang.contrib import nvcc + from tilelang.utils.target import determine_target + + target = determine_target(return_object=True) # get target GPU + compute_version = nvcc.get_target_compute_version(target) # e.g. "8.6" + major, minor = nvcc.parse_compute_version(compute_version) # split to (8, 6) + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" # set env var for PyTorch + + # Cache control API (wrap CacheState) + def is_cache_enabled(self) -> bool: + return not self.is_cache_globally_disabled() and CacheState.is_enabled() + + def enable_cache(self) -> None: + CacheState.enable() + + def disable_cache(self) -> None: + CacheState.disable() + + def is_cache_globally_disabled(self) -> bool: + return self.TILELANG_DISABLE_CACHE.lower() in ("1", "true", "yes", "on") + + def is_autotune_cache_disabled(self) -> bool: + return self.TILELANG_AUTO_TUNING_DISABLE_CACHE.lower() in ("1", "true", "yes", "on") + + def is_print_on_compilation_enabled(self) -> bool: + return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on") + + def use_gemm_v1(self) -> bool: + """Return True if GEMM v1 should be used based on env. + + Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of + {"1", "true", "yes", "on"} (case-insensitive). 
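+
+        Example (illustrative)::
+
+            os.environ["TILELANG_USE_GEMM_V1"] = "yes"
+            assert env.use_gemm_v1()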
+ """ + return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on") + + +# Instantiate as a global configuration object +env = Environment() + +# Cache control API (wrap env, which is managed by CacheState and Environment Variables jointly) +enable_cache = env.enable_cache # CacheState.enable +disable_cache = env.disable_cache # CacheState.disable +is_cache_enabled = env.is_cache_enabled # CacheState.is_enabled + +# Export CUDA_HOME and ROCM_HOME, both are static variables +# after initialization. +CUDA_HOME = env.CUDA_HOME +ROCM_HOME = env.ROCM_HOME + + +def prepend_pythonpath(path): + if not os.environ.get("PYTHONPATH", None): + os.environ["PYTHONPATH"] = path + else: + os.environ["PYTHONPATH"] = path + os.pathsep + os.environ["PYTHONPATH"] + + sys.path.insert(0, path) + + +# Initialize TVM paths +if env.TVM_IMPORT_PYTHON_PATH is not None: + prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH) +else: + tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python") + assert os.path.exists(tvm_path), tvm_path + if tvm_path not in sys.path: + prepend_pythonpath(tvm_path) + env.TVM_IMPORT_PYTHON_PATH = tvm_path +# By default, the built TVM-related libraries are stored in TL_LIBS. +if os.environ.get("TVM_LIBRARY_PATH") is None: + os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) + +# Initialize CUTLASS paths +if os.environ.get("TL_CUTLASS_PATH", None) is None: + cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include") + if os.path.exists(cutlass_inc_path): + os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path + else: + logger.warning(CUTLASS_NOT_FOUND_MESSAGE) + +# Initialize COMPOSABLE_KERNEL paths +if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: + ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include") + if os.path.exists(ck_inc_path): + os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path + else: + logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE) + +# Initialize TL_TEMPLATE_PATH +if os.environ.get("TL_TEMPLATE_PATH", None) is None: + tl_template_path = os.path.join(THIRD_PARTY_ROOT, "..", "src") + if os.path.exists(tl_template_path): + os.environ["TL_TEMPLATE_PATH"] = env.TILELANG_TEMPLATE_PATH = tl_template_path + else: + logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) + +# Export static variables after initialization. 
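+# (These names capture the values resolved into ``env`` at import time; unlike
+# the EnvVar descriptors, they do not reflect later environment changes.)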
+CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR +COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR +TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH diff --git a/tilelang/original/tilelang/intrinsics/__init__.py b/tilelang/original/tilelang/intrinsics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3f106e71608d0be306e0ee5b63a370a80a8dc4 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/__init__.py @@ -0,0 +1,14 @@ +from .utils import ( + mma_store_index_map, # noqa: F401 + get_ldmatrix_offset, # noqa: F401 +) + +from .mma_macro_generator import ( + TensorCoreIntrinEmitter, # noqa: F401 + TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 +) + +from .mma_layout import get_swizzle_layout # noqa: F401 +from .mma_layout import make_mma_swizzle_layout # noqa: F401 + +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 diff --git a/tilelang/original/tilelang/intrinsics/mfma_layout.py b/tilelang/original/tilelang/intrinsics/mfma_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..38959649467cdfd9decd2fd73c3b4c46e8868ea0 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mfma_layout.py @@ -0,0 +1,152 @@ +from tvm import DataType +from tvm.runtime import convert +import tilelang.language as T + + +def shared_16x4_to_local_64x1_layout_A(i, j): + thread_id = j * 16 + i + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): + i = thread_id % 16 + j = thread_id // 16 + return i, j + + +def shared_4x16_to_local_64x1_layout_B(i, j): + thread_id = i * 16 + j + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): + i = thread_id // 16 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_C(i, j): + thread_id = j + (i // 4) * 16 + local = i % 4 + return thread_id, local + + +def shared_16x16_to_ldmatrix_64x4_layout(ind): + i, j = ind[0], ind[1] + thread_id, local_id = shared_16x16_to_local_64x4_layout_C(i, j) + return convert([thread_id, local_id]) + + +def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): + i = thread_id % 16 + j = (thread_id // 16) * 4 + local_id + return i, j + + +def shared_16x16_to_local_64x4_layout_A(i, j): + thread_id = i + 16 * (j // 4) + local = j % 4 + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_B(i, j): + thread_id = j + (i // 4) * 16 + local = i % 4 + return thread_id, local + + +shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A +shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A +shared_16x16_to_local_64x4_layout_n_m = shared_16x16_to_local_64x4_layout_B +shared_16x16_to_local_64x4_layout_k_n = shared_16x16_to_local_64x4_layout_B + + +def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id): + i = thread_id % 16 + j = local_id + (thread_id // 16) * 4 + return i, j + + +def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id): + i = thread_id % 16 + j = (thread_id // 16) * 8 + local_id + return i, j + + +def shared_16x32_to_local_64x8_layout_A(i, j): + thread_id = i + 16 * (j // 8) + local = j % 8 + return 
thread_id, local + + +def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id): + i = local_id + (thread_id // 16) * 8 + j = thread_id % 16 + return i, j + + +def shared_16x32_to_local_64x8_layout_B(i, j): + thread_id = j + (i // 8) * 16 + local = i % 8 + return thread_id, local + + +def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id): + i = thread_id % 16 + j = local_id + (thread_id // 16) * 16 + return i, j + + +def shared_16x64_to_local_64x16_layout_A(i, j): + thread_id = i + 16 * (j // 16) + local = j % 16 + return thread_id, local + + +def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id): + i = local_id + (thread_id // 16) * 16 + j = thread_id % 16 + return i, j + + +def shared_16x64_to_local_64x16_layout_B(i, j): + thread_id = i + 16 * (j // 16) + local = j % 16 + return thread_id, local + + +def make_mfma_swizzle_layout(shared_buf, vecSize=8): + dtype = shared_buf.dtype + shape = shared_buf.shape + + numBanks = 32 + bankBitWidth = 32 + SIMDWidth = 16 + + innerDimLength = shape[-1] + typeWidthInBit = DataType(dtype).bits + + elemsPerOneBanksRow = (numBanks * bankBitWidth) // typeWidthInBit + perPhase = max(1, elemsPerOneBanksRow // innerDimLength) + maxPhase = min(SIMDWidth // perPhase, innerDimLength // vecSize) + + def transform(row, col): + phase = (row // perPhase) % maxPhase + colOffSwizzled = ((col // vecSize) ^ phase) * vecSize + colOffOrdered = col % vecSize + colOff = colOffSwizzled + colOffOrdered + return row, colOff + + return T.Layout(shape, transform) diff --git a/tilelang/original/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/original/tilelang/intrinsics/mfma_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ad219206101368c84314546dbd9b4ce90295efef --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mfma_macro_generator.py @@ -0,0 +1,866 @@ +from __future__ import annotations +from tilelang import tvm as tvm +import tilelang.language as T +from tvm import DataType +from tvm import tir +from tvm.ir import Range +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad +from tvm.runtime import convert +from .utils import mfma_store_index_map +from typing import Literal, Callable + +from tilelang.utils import is_fragment +from tilelang.utils.language import get_buffer_region_from_load +from .mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_4x16_to_local_64x1_layout_B, + shared_16x16_to_local_64x4_layout_A, + shared_16x16_to_local_64x4_layout_B, + shared_16x32_to_local_64x8_layout_A, + shared_16x32_to_local_64x8_layout_B, + shared_16x64_to_local_64x16_layout_A, + shared_16x64_to_local_64x16_layout_B, + thread_id_shared_access_64x1_to_16x4_layout_A, + thread_id_shared_access_64x1_to_4x16_layout_B, + thread_id_shared_access_64x4_to_16x16_layout_A, + thread_id_shared_access_64x4_to_16x16_layout_B, + thread_id_shared_access_64x8_to_16x32_layout_A, + thread_id_shared_access_64x8_to_16x32_layout_B, + thread_id_shared_access_64x16_to_16x64_layout_A, + thread_id_shared_access_64x16_to_16x64_layout_B, +) + +lift = convert + + +class MatrixCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. 
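+
+    A typical construction (illustrative; the tile sizes below are placeholder values)::
+
+        emitter = MatrixCoreIntrinEmitter(
+            a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float32,
+            block_row_warps=2, block_col_warps=2,
+            warp_row_tiles=32, warp_col_tiles=32, chunk=32,
+        )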
+ """ + + M_DIM = 16 + N_DIM = 16 + WARP_SIZE = 64 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e5m2": "e5m2", + "float8_e4m3fnuz": "e4m3fnuz", + } + + # k_pack represents the number of elements in a vectorized instruction + # Detail information can be found in the triton documentation + # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419 + k_pack = 1 + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + k_pack: int | None = None, + is_m_first: bool | None = False, + b_preshuffle: bool | None = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) + self._initialize_mfma_prefix(self.k_dim) + self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) + self._initialize_k_pack(k_pack) + self._initialize_is_m_first(is_m_first) + self._initialize_b_preshuffle(b_preshuffle) + + self.warp_rows = warp_row_tiles // self.micro_size_x + self.warp_cols = warp_col_tiles // self.micro_size_y + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + def _initialize_k_dim(self, a_dtype=T.float16): + if isinstance(a_dtype, str): + if a_dtype in ["float8_e4m3fnuz", T.int8]: + self.k_dim = 32 + return + a_dtype = DataType(a_dtype) + + if a_dtype.bits == 32: + self.k_dim = 4 + elif a_dtype.bits in {16, 8}: + self.k_dim = 16 + else: + raise ValueError(f"Unsupported a_dtype = {a_dtype}") + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mfma_prefix(self, k_dim=16): + in_dtype, out_dtype = self.a_dtype, self.accum_dtype + M_DIM, N_DIM = self.M_DIM, self.N_DIM + out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype] + + in_dtype_abbrv = { + "bfloat16": "bf16", + "float16": "f16", + "float32": "f32", + "int8": "i8", + "int32": "i32", + "float8_e4m3fnuz": "fp8", + }[in_dtype] + + if in_dtype_abbrv == "fp8": + self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8" + elif in_dtype_abbrv == "i8": + 
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8" + elif in_dtype_abbrv == "bf16": + # HIP intrinsic uses ...x{K}bf16_1k without an underscore before bf16 + self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}bf16_1k" + else: + self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" + + def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): + self.micro_size_x = m_dim + self.micro_size_y = n_dim + self.micro_size_k = k_dim + + def _initialize_k_pack(self, k_pack: int | None = None): + if k_pack is not None: + self.k_pack = k_pack + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False): + if b_preshuffle is not None: + self.b_preshuffle = b_preshuffle + + def get_ldmatrix_index_map(self, is_b=False): + k_dim = self.k_dim * self.k_pack + transposed = self.a_transposed if not is_b else self.b_transposed + if k_dim == 4: + index_map = shared_16x4_to_local_64x1_layout_A + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A + if is_b: + index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B + reverse_index_map = ( + thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + ) + elif k_dim == 16: + index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + ) + + if is_b: + index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + ) + elif k_dim == 32: + index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + ) + + if is_b: + index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + ) + elif k_dim == 64: + index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + ) + + if is_b: + index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + ) + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + return index_map, reverse_index_map + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func(mfma_store_index_map, index_dtype=T.int32) + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = 
T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + k_pack = self.k_pack + is_transposed = self.a_transposed + thread_binding = self.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) + + # legalize shared buffer to region + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + k_pack = self.k_pack + is_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) + + # legalize shared buffer to region + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + 
rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * (k_pack * micro_size_k), + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * chunk + ki * (k_pack * micro_size_k), + warp_n * warp_col_tiles + j * micro_size_y, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + k_pack = self.k_pack + mfma_suffix = self.mfma_suffix + a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype + compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}" + compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" + compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" + + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0 + + @T.macro + def _warp_mfma(A_local_buf, B_local_buf, C_local_buf): + for kp, i, j in T.grid(k_pack, warp_rows, warp_cols): + T.tvm_mfma( + mfma_suffix, + "row", + "row", + compute_a_dtype, + compute_b_dtype, + compute_out_dtype, + B_local_buf.data, + (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b, + A_local_buf.data, + (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a, + C_local_buf.data, + (i * warp_cols * local_size_out + j * local_size_out) // local_size_out, + dtype=compute_out_dtype, + ) + + return _warp_mfma(A_local_buf, B_local_buf, C_local_buf) + + def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + thread_binding = self.get_thread_binding() + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, N_DIM = self.M_DIM, self.N_DIM + C_buf_dims = len(C_buf.shape) + assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + + # STS + # MFMA Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # equal to the warp_size + @T.macro + def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id in T.vectorized(local_size_out): + row, col = T.meta_var(mfma_store_index_map(tx, local_id)) + if C_buf_dims == 2: + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + 
col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] + else: + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * warp_cols * local_size_out + j * local_size_out + local_id + ] + + @T.macro + def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id in T.vectorized(local_size_out): + row, col = T.meta_var(mfma_store_index_map(tx, local_id)) + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] + + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) + + def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + k_dim = self.k_dim * self.k_pack + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + 
self.block_row_warps, + self.block_col_warps, + ) + + inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mfma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mfma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.chunk + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // (micro_size_r * self.k_pack) + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + inverse_mfma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mfma_store_layout`. 
+ """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y + mfma_i, mfma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mfma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y + mfma_i, mfma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + @staticmethod + def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: + """ + Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. + + - Buffer -> full-region BufferRegion covering entire shape + - BufferRegion -> returned as-is + - BufferLoad -> best-effort convert via get_buffer_region_from_load; + if scalar, fall back to 1-sized ranges at given indices + """ + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + # Fallback: scalar load -> 1-sized ranges at indices + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + + +class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + k_pack: int | None = None, + is_m_first: bool | None = False, + a_preshuffle: bool | None = False, + b_preshuffle: bool | None = False, + thread_var: Var | None = None, + ): + super().__init__( + a_dtype=a_dtype, + b_dtype=b_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, 
+ warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + num_elems_per_byte=num_elems_per_byte, + k_pack=k_pack, + is_m_first=is_m_first, + thread_var=thread_var, + ) + self._initialize_preshuffle(a_preshuffle, b_preshuffle) + + def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): + if a_preshuffle is not None: + self.a_preshuffle = a_preshuffle + if b_preshuffle is not None: + self.b_preshuffle = b_preshuffle + + def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None): + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + k_pack = self.k_pack + is_transposed = self.a_transposed + current_frame = T.KernelLaunchFrame.Current() + thread_binding = current_frame.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) + is_global = pid_m is not None and pid_n is not None + + # no preshuffle, use the default implementation + if self.a_preshuffle is False: + return super().ldmatrix_a(A_local_buf, A_buf, ki, rk) + + @T.macro + def _warp_ldmatrix_a_global( + A_local_buf, + A_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + (pid_m * self.block_row_warps + warp_m) * warp_rows + i, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + (pid_m * self.block_row_warps + warp_m) * warp_rows + i, + rk * (chunk // micro_size_k) + ki, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col] + + @T.macro + def _warp_ldmatrix_a_shared( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + warp_m * warp_rows + i, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki) + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] + + return ( + _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk) + ) + + def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + k_pack = self.k_pack + is_transposed = self.b_transposed + current_frame = T.KernelLaunchFrame.Current() + thread_binding = current_frame.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) + is_global = pid_m is not None and pid_n is not None + + if self.b_preshuffle is False: + return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n) + + @T.macro + def _warp_ldmatrix_b_global( +
B_local_buf, + B_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + (pid_n * self.block_col_warps + warp_n) * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col] + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + (pid_n * self.block_col_warps + warp_n) * warp_cols + j, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col] + + @T.macro + def _warp_ldmatrix_b_shared( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + warp_n * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + warp_n * warp_cols + j, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] + + return ( + _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk) + ) diff --git a/tilelang/original/tilelang/intrinsics/mma_layout.py b/tilelang/original/tilelang/intrinsics/mma_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb575f0ca1fc3e070e3e9402439dcfdd936131e --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mma_layout.py @@ -0,0 +1,254 @@ +from __future__ import annotations +from tvm import arith, DataType +import tilelang.language as T + + +def ldmatrix_32x4_to_shared_16x8_layout_a(thread_id, local_id): + row = thread_id % 16 + col = (thread_id // 16) * 4 + local_id % 4 + return row, col + + +def ldmatrix_32x4_to_shared_16x8_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = ((thread_id % 16) // 8) * 4 + local_id % 4 + return row, col + + +def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): + row = thread_id % 16 + col = 8 * (thread_id // 16) + local_id % 8 + return row, col + + +def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (thread_id // 16) + (thread_id % 8) + col = 8 * ((thread_id % 16) // 8) + local_id % 8 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): + row = thread_id % 16 + col = local_id + (thread_id // 16) * 16 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + 16 * ((thread_id % 16) // 8) + return row, col + + +def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (local_id % 4 // 2) + (thread_id // 4) + col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) + return row, col + + +def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id): + row = thread_id // 4 + col = 
(thread_id % 4) * 2 + local_id + return row, col + + +# sr represents spatial + reduction layout +# the first axis is spatial while the second axis is reduction +# mma.sync matrix A layout, if wanna trans, please apply map_indices +def shared_16x8_to_mma_a_32x4_layout(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 2 * (j // 4) + (i // 8) + + +def shared_16x8_to_mma_a_32x4_layout_trans(i, j): + return shared_16x8_to_mma_a_32x4_layout(j, i) + + +# mma.sync matrix B layout, if wanna trans, please apply map_indices +def shared_16x8_to_mma_b_32x4_layout(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 2 * (i // 8) + (j // 4) + + +def shared_16x8_to_mma_b_32x4_layout_trans(i, j): + return shared_16x8_to_mma_b_32x4_layout(j, i) + + +shared_16x8_to_mma_32x4_layout_sr_a = shared_16x8_to_mma_a_32x4_layout +shared_16x8_to_mma_32x4_layout_sr_b = shared_16x8_to_mma_b_32x4_layout +shared_16x8_to_mma_32x4_layout_rs_a = shared_16x8_to_mma_a_32x4_layout_trans +shared_16x8_to_mma_32x4_layout_rs_b = shared_16x8_to_mma_b_32x4_layout_trans + + +def shared_16x16_to_mma_a_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) + + +def shared_16x16_to_mma_a_32x8_layout_trans(i, j): + return shared_16x16_to_mma_a_32x8_layout(j, i) + + +def shared_16x16_to_mma_b_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2) + + +def shared_16x16_to_mma_b_32x8_layout_trans(i, j): + return shared_16x16_to_mma_b_32x8_layout(j, i) + + +shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout +shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout +shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans +shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans + + +def shared_16x32_to_mma_a_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 + + +def shared_32x16_to_mma_a_32x16_layout_trans(i, j): + return shared_16x32_to_mma_a_32x16_layout(j, i) + + +def shared_16x32_to_mma_b_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (i // 8) + (j // 16) * 4 + j % 4 + + +def shared_32x16_to_mma_b_32x16_layout_trans(i, j): + return shared_16x32_to_mma_b_32x16_layout(j, i) + + +shared_16x32_to_mma_32x16_layout_sr_a = shared_16x32_to_mma_a_32x16_layout +shared_16x32_to_mma_32x16_layout_sr_b = shared_16x32_to_mma_b_32x16_layout +shared_16x32_to_mma_32x16_layout_rs_a = shared_32x16_to_mma_a_32x16_layout_trans +shared_16x32_to_mma_32x16_layout_rs_b = shared_32x16_to_mma_b_32x16_layout_trans + + +def mma_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (local_id % 4 // 2) + (thread_id // 4) + col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) + return row, col + + +def mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id): + row = 8 * (local_id % 2) + (thread_id // 4) + col = 4 * (local_id // 2) + (thread_id % 4) + return row, col + + +def mma_load_b_32x4_to_shared_16x8_layout(thread_id, local_id): + row = 8 * (local_id // 2) + (thread_id // 4) + col = 4 * (local_id % 2) + (thread_id % 4) + return row, col + + +def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id % 8 // 4) + (thread_id // 4) + col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4) + return row, col + + +def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id): 
+ """ + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 + + row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 + groupID + 8 Otherwise + + col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 + (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 + """ + row = (thread_id // 4) + 8 * (local_id % 4 // 2) + col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4) + return row, col + + +def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id // 8) + (thread_id // 4) + col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4) + return row, col + + +def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): + """ + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 + + row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 + (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 + + col = groupID + """ + col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8 + row = (thread_id // 4) + 8 * (local_id // 4) + return row, col + + +def shared_16x16_to_mma_32x8_smoothlayout(i, j): + return (i * 2 + j // 8, j % 8) + + +def shared_16x32_to_mma_32x16_smoothlayout(i, j): + return (i * 2 + j // 16, j % 16) + + +def shared_32x16_to_mma_32x16_smoothlayout(i, j): + return (i * 2 + j // 16, j % 16) + + +def get_swizzle_layout(row_idx, col_idx, row_size, dtype: DataType | str, swizzle_bytes=None): + ana = arith.Analyzer() + if isinstance(dtype, str): + dtype = DataType(dtype) + row_bytes = dtype.bits * row_size // 8 + assert row_bytes % 32 == 0, "Row size must be multiple of 32B." + if swizzle_bytes is None: + swizzle_bytes = min(128, row_bytes) + # 128B swizzle + # Use 8 * 8 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 + # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 + # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 + # 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4 + # 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3 + # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 + # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 + # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 + # 64B swizzle + # Use 8 * 4 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 + # 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 + # 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 + # 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 + # 32B swizzle + # Use 8 * 2 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 + # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 + elem_per_16B = 128 // dtype.bits + col_idx_16B = col_idx // elem_per_16B + col_idx_in_16B = col_idx % elem_per_16B + new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16)) + return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B) + + +def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits % 512 == 0 + if is_smooth or (not can_swizzle): + return T.Layout(shape, lambda *args: args) + + def transform_func(*args): + i, j = args[-2:] + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [*args[:-2], new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) diff --git a/tilelang/original/tilelang/intrinsics/mma_macro_generator.py b/tilelang/original/tilelang/intrinsics/mma_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..4b41eef2af58e515d48b80d9a81adbadc0867f19 --- /dev/null +++ 
b/tilelang/original/tilelang/intrinsics/mma_macro_generator.py @@ -0,0 +1,1358 @@ +from __future__ import annotations +import tilelang.language as T +from typing import Literal, Callable +from tilelang.common import TransformKind +from tvm import DataType +from tvm import tir +from tvm.ir import Range +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad +from tilelang import tvm as tvm +from tvm.runtime import convert +from .utils import ( + mma_store_index_map, + get_ldmatrix_offset, +) +from tilelang.utils import is_fragment, get_buffer_region_from_load +from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x8_to_mma_32x4_layout_sr_b, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_b, + shared_16x32_to_mma_32x16_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_b, + mma_load_a_32x4_to_shared_16x8_layout, + mma_load_b_32x4_to_shared_16x8_layout, + mma_load_b_32x8_to_shared_16x16_layout, + mma_load_a_32x16_to_shared_16x32_layout, + mma_load_b_32x16_to_shared_16x32_layout, + mma_load_a_32x8_to_shared_16x16_layout, +) + +lift = convert + + +class TensorCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. + """ + + M_DIM = 16 + # use lowercase as n_dim can be dynamic + # the smallest instructions can be m16n8k16, so the n_dim can also be 8 + n_dim = 16 + WARP_SIZE = 32 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e4m3fn": "e4m3", + "float8_e4m3fnuz": "e4m3", + "float8_e5m2": "e5m2", + "float8_e5m2fnuz": "e5m2", + } + + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool | None = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self._initialize_k_dim(a_dtype) + # For FP64, MMA shape is m8n8k4; adjust instance dims early + if DataType(a_dtype).bits == 64: + # Override default M/N dims for fp64 MMA + self.M_DIM = 8 + # n_dim will be set to 8 in _initialize_micro_size via k_dim==4 + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_micro_size(self.M_DIM, self.k_dim) + self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) + self._initialize_mma_prefix(self.k_dim) + self._initialize_is_m_first(is_m_first) + + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + if self.warp_rows == 0 or self.warp_cols == 0: + raise ValueError( + f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" + ) + + def _initialize_k_dim(self, a_dtype=T.float16): + if isinstance(a_dtype, 
str): + a_dtype = DataType(a_dtype) + self.k_dim = 256 // a_dtype.bits + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype) + self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype) + self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype) + + def _get_dtype_abbrv(self, dtype: str) -> str: + try: + return self.dtype_abbrv[dtype] + except KeyError as err: + raise ValueError(f"Unsupported dtype: {dtype}") from err + + def _initialize_mma_prefix(self, k_dim: int = 16): + if k_dim == 4: + # fp64 + self.mma_prefix = "m8n8k4" + elif k_dim == 8: + # typically used for tfloat32 + self.mma_prefix = "m16n8k8" + elif k_dim == 16: + # typically used for float16/bfloat16 + self.mma_prefix = "m16n8k16" + elif k_dim == 32: + # typically used for int8/fp8 + self.mma_prefix = "m16n8k32" + else: + raise ValueError("Unsupported k_dim") + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + # For fp64 (k_dim==4), micro tile is 8x8, otherwise keep 16x{8|16} + if k_dim == 4: + # fp64 path: m_dim must be 8, n_dim 8 + assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}" + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_rows = warp_row_tiles // m_dim + self.warp_cols = warp_col_tiles // 8 + else: + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + self.warp_rows = warp_row_tiles // m_dim + + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + self.micro_size_k = k_dim + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + from .utils import mma_store_index_map, mma_store_index_map_fp64 + + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + if DataType(self.accum_dtype).bits == 64: + index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype=T.int32) + else: + index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32) + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split 
m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): + # Fast path for fp64: no ldmatrix support, do direct per-lane loads + if DataType(self.a_dtype).bits == 64: + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x # 8 + micro_size_k = self.micro_size_k # 4 + local_size_a = self.local_size_a # 1 + a_transposed = self.a_transposed + + thread_binding = self.get_thread_binding() + # legalize shared buffer to region + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + + @T.macro + def _warp_ld_a_fp64( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + for i in T.serial(warp_rows): + wi = warp_m * warp_row_tiles + i * micro_size_x + wk = rk * chunk + ki * micro_size_k + mi = tx // micro_size_k + mk = tx % micro_size_k + if a_transposed: + A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] + else: + A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] + + return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk) + + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + # ldmatrix cannot be used for int8 + trans case. 
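+ # More precisely: PTX ldmatrix can only transpose 8x8 tiles of 16-bit (.b16) elements,
+ # so whenever A is transposed and its dtype is not 16-bit we skip ldmatrix and instead
+ # perform explicit per-lane loads below, using the matching mma_load_*_to_shared_* index
+ # map imported from mma_layout.py to pick each lane's (row, col) in shared memory.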
+ ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(a_dtype).bits == 8: + mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout + elif DataType(a_dtype).bits == 16: + mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout + elif DataType(a_dtype).bits == 32: + mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout + else: + raise ValueError(f"Unsupported dtype: {a_dtype}") + + thread_binding = self.get_thread_binding() + + # legalize shared buffer to region + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + A_stride_last = A_buf.shape[-1] + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = A_stride_last + tx, _, warp_m = self.extract_thread_binding(thread_binding) + trans = self.a_transposed + + for i in T.serial(warp_rows): + # Assign A_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k + A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk] + + if ldmatrix_available: + T.ptx_ldmatrix( + a_dtype, + T.bool(trans), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + else: + for j in T.serial(local_size_a): + mi, mk = mma_load_layout(tx, j) + if a_transposed: + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] + else: + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] + + return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): + # Fast path for fp64: no ldmatrix support, do direct per-lane loads + if DataType(self.b_dtype).bits == 64: + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y # 8 + micro_size_k = self.micro_size_k # 4 + local_size_b = self.local_size_b # 1 + b_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + + # legalize shared buffer to region + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + + @T.macro + def _warp_ld_b_fp64( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + for j in T.serial(warp_cols): + wi = warp_n * warp_col_tiles + j * micro_size_y + wk = rk * chunk + ki * micro_size_k + mi = tx // micro_size_k + mk = tx % micro_size_k + if b_transposed: + B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] + else: + B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] + + return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk) + + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + b_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + + # legalize shared buffer to region + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = 
B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + B_stride_last = B_buf.shape[-1] + replicate_b = self.n_dim == 16 + # ldmatrix cannot be used for int8 + trans case. + ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(b_dtype).bits == 8: + mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout + elif DataType(b_dtype).bits == 16: + mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout + elif DataType(b_dtype).bits == 32: + mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout + else: + raise ValueError(f"Unsupported dtype: {b_dtype}") + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = B_stride_last + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + trans = not b_transposed + + for i in T.serial(warp_cols): + # Assign B_shared_elem + wi, wk = ( + warp_n * warp_col_tiles + i * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + + if ldmatrix_available: + B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi] + + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4 if replicate_b else 2, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + ) + + else: + # load 16x32 data from shared buffer to local buffer + # must be transposed. + for j in T.serial(local_size_b): + mi, mk = mma_load_layout(tx, j) + if b_transposed: + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] + else: + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + replicate_b = self.n_dim == 16 + + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), # saturate + ) + if replicate_b: + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), # saturate + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + 
+ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, n_dim = self.M_DIM, self.n_dim + C_buf_dims = len(C_buf.shape) + assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + + thread_binding = self.get_thread_binding() + + # STS + # MMA Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # equal to the warp_size + @T.macro + def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + if C_buf_dims == 2: + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] + else: + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] + + @T.macro + def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] + + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. 
+ """ + from tilelang.utils import is_fragment + + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + dtype = self.a_dtype if matrix_is_a else self.b_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.chunk + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, 
lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}" + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. 
+ """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + @staticmethod + def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: + """ + Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. + + - Buffer -> full-region BufferRegion covering entire shape + - BufferRegion -> returned as-is + - BufferLoad -> best-effort convert via get_buffer_region_from_load; + if scalar, fall back to 1-sized ranges at given indices + """ + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + # Fallback: scalar load -> 1-sized ranges at indices + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + + +class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): + """ + To eliminate Python syntax within TIR Macro. + With Ladder Transform Plugin. 
+ """ + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool | None = False, + transform_kind_a: int | TransformKind = 0, + transform_kind_b: int | TransformKind = 0, + ): + super().__init__( + a_dtype=a_dtype, + b_dtype=b_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + num_elems_per_byte=num_elems_per_byte, + is_m_first=is_m_first, + ) + self._initialize_transform_kind(transform_kind_a, transform_kind_b) + + def _initialize_k_dim(self, a_dtype=T.float16): + self.k_dim = 256 // DataType(a_dtype).bits + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mma_prefix(self, k_dim=16): + if k_dim == 16: + self.mma_prefix = "m16n8k16" + elif k_dim == 32: + self.mma_prefix = "m16n8k32" + else: + raise ValueError("Unsupported k_dim") + + def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): + self.micro_size_x = m_dim + self.micro_size_y = n_dim + self.micro_size_k = k_dim + + def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): + if isinstance(transform_kind_a, int): + self.transform_kind_a = TransformKind(transform_kind_a) + elif isinstance(transform_kind_a, TransformKind): + self.transform_kind_a = transform_kind_a + else: + raise ValueError("Unsupported transform_kind_a") + + if isinstance(transform_kind_b, int): + self.transform_kind_b = TransformKind(transform_kind_b) + elif isinstance(transform_kind_b, TransformKind): + self.transform_kind_b = transform_kind_b + else: + raise ValueError("Unsupported transform_kind_b") + + assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3" + assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3" + + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + transform_kind_a = self.transform_kind_a + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if transform_kind_a == TransformKind.NonTransform: + for i in T.serial(warp_rows): + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of( + A_shared_buf[ + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ] + ), + 
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.InterWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + args = (ni, nj, nii, njj) if transform_kind_a > 0 else (ri, rj) + A_shared_elem = A_shared_buf[args] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.IntraWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + A_shared_elem = A_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + tx * local_size_a, + ) + elif transform_kind_a == TransformKind.LDMatrixTransform: + for j in T.serial(warp_rows): + for local_id in T.vectorized(local_size_a): + # Assign A_shared_elem + ri, rj = ( + warp_m * warp_rows + j, + rk * (chunk // micro_size_k) + ki, + ) + rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k) + A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj] + else: + raise ValueError("Unsupported TransformKind for Input A") + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + transform_kind_b = self.transform_kind_b + b_transposed = self.b_transposed + num_elems_per_byte = self.num_elems_per_byte + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = B_shared_buf.shape[-1] + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + + if transform_kind_b == TransformKind.NonTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + B_shared_elem = B_shared_buf[ri, rj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + ) + elif transform_kind_b == TransformKind.InterWarpTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_y, + (rj) // micro_size_k, + (ri) % micro_size_y, + (rj) % micro_size_k, + ) + B_shared_elem = B_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), 
+ ) + elif transform_kind_b == TransformKind.IntraWarpTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_y, + (rj) // micro_size_k, + (ri) % micro_size_y, + (rj) % micro_size_k, + ) + B_shared_elem = B_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + tx * local_size_b, + ) + elif transform_kind_b == TransformKind.LDMatrixTransform: + local_size_dequantize = local_size_b // num_elems_per_byte + for j in T.serial(warp_cols): + for local_id in T.vectorized(local_size_dequantize): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + rii, rjj = ( + (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte), + (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte), + ) + B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj] + else: + raise ValueError("Unsupported TransformKind for Input B") + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + +class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = accum_dtype + mma_prefix = "m16n8k32" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + """ + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + """ + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + 
A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + +class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = T.int32 + mma_prefix = "m16n8k32" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + """ + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + """ + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 
8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/tilelang/original/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/original/tilelang/intrinsics/mma_sm70_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..8029234414710af6e923b9d68e0df9a4fc9f4801 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mma_sm70_layout.py @@ -0,0 +1,46 @@ +def shared_16x4_to_mma_a_32x4_layout(row, col, rep): + tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep + local_id = col + return tid, local_id + + +def shared_4x16_to_mma_b_32x4_layout(row, col, rep): + thread_id = row + 8 * col // 4 + 4 * rep + local_id = col % 4 + return thread_id, local_id + + +def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep): + thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8) + local_id = col + return thread_id, local_id + + +def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id): + row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 + col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8 + return row, col + + +def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id): + row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8 + col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8 + return row, col + + +def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id): + row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4)) + col = local_id + return row, col + + +def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id): + row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2) + col = local_id + return row, col + + +def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id): + row = thread_id % 4 + col = local_id + (4 * (thread_id // 8)) + return row, col diff --git a/tilelang/original/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/original/tilelang/intrinsics/mma_sm70_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6acc40a4cd56869493bb74465de7e91dd64e7ba4 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mma_sm70_macro_generator.py @@ -0,0 +1,495 @@ +from __future__ import annotations +import tilelang.language as T +from typing import Literal, Callable +from tvm import DataType +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion +from tilelang import tvm as tvm +from tvm.runtime import convert +from tilelang.utils import is_fragment +from tilelang.intrinsics.mma_sm70_layout import ( + shared_16x4_to_mma_a_32x4_layout, + shared_4x16_to_mma_b_32x4_layout, + shared_16x4_to_mma_b_32x4_layout_trans, + mma_32x8_to_shared_16x16_layout_fp32, + mma_32x8_to_shared_16x16_layout_fp16, + mma_load_a_32x4_to_shared_16x4_layout, + mma_load_b_32x4_to_shared_16x4_layout_trans, + mma_load_b_32x4_to_shared_4x16_layout, +) + +lift = convert + + +class TensorCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. 
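+ 
+     SM70 (Volta) variant: emits m16n16k4 mma.sync through `T.ptx_mma_sm70`; operands
+     are loaded element-wise using the mapping helpers from `mma_sm70_layout`.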
+ """ + + M_DIM = 16 + # use lowercase as n_dim can be dynamic + # the smallest instructions can be m16n8k16, so the n_dim can also be 8 + n_dim = 16 + WARP_SIZE = 32 + HALF_WARP_SIZE = WARP_SIZE // 2 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e5m2": "e5m2", + } + + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool | None = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_micro_size(self.M_DIM, self.k_dim) + self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim) + self._initialize_mma_prefix(self.k_dim) + self._initialize_is_m_first(is_m_first) + + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + if self.warp_rows == 0 or self.warp_cols == 0: + raise ValueError( + f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" + ) + + def _initialize_k_dim(self, a_dtype=T.float16): + self.k_dim = 4 + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16): + self.local_size_a = (m_dim * k_dim) // self.HALF_WARP_SIZE + self.local_size_b = (n_dim * k_dim) // self.HALF_WARP_SIZE + self.local_size_out = (m_dim * n_dim) // self.WARP_SIZE + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype) + self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype) + self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype) + + def _get_dtype_abbrv(self, dtype: str) -> str: + try: + return self.dtype_abbrv[dtype] + except KeyError as err: + raise ValueError(f"Unsupported dtype: {dtype}") from err + + def _initialize_mma_prefix(self, k_dim: int = 16): + if k_dim == 4: + # typically used for float16 + self.mma_prefix = "m16n16k4" + else: + raise ValueError(f"Unsupported k_dim: {k_dim}") + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 16, f"warp_col_tiles must be greater than 16, got {warp_col_tiles}" + assert warp_col_tiles % 16 == 0, f"warp_col_tiles must be divisible by 16, got {warp_col_tiles}" + + self.warp_rows = warp_row_tiles // m_dim + + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + + self.micro_size_x = m_dim + 
self.micro_size_k = k_dim + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func( + mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == T.float32 else mma_32x8_to_shared_16x16_layout_fp16, + index_dtype=T.int32, + ) + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_transposed = self.a_transposed + + thread_binding = self.get_thread_binding() + + assert not a_transposed, "A must be not transposed" + + mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout + + # legalize shared buffer to region + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + + for i in T.serial(warp_rows): + # Assign A_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k + for j in T.vectorized(local_size_a): + mi, mk = mma_load_layout(tx, j) + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] + + return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_transposed = self.b_transposed + thread_binding = 
self.get_thread_binding() + + mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout + + # legalize shared buffer to region + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + + for i in T.serial(warp_cols): + # Assign B_shared_elem + wi, wk = ( + warp_n * warp_col_tiles + i * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + # load 16x32 data from shared buffer to local buffer + # must be transposed. + for j in T.vectorized(local_size_b): + if b_transposed: + mi, mk = mma_load_layout(tx, j) + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] + else: + mk, mi = mma_load_layout(tx, j) + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] + + return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk) + + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + + a_major = "col" if self.a_transposed else "row" + b_major = "col" if self.b_transposed else "row" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma_sm70( + mma_prefix, + a_major, + b_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. 
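+ 
+     Notes
+     -----
+     This fragment describes how an operand (A or B) tile is loaded for the MMA;
+     the accumulator layout is produced by `make_mma_store_layout`.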
+ """ + from tilelang.utils import is_fragment + + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + dtype = self.a_dtype if matrix_is_a else self.b_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + transform_func_rs_b: Callable = None + if dtype_bits == 16: + transform_func_sr_a = shared_16x4_to_mma_a_32x4_layout + transform_func_sr_b = shared_16x4_to_mma_b_32x4_layout_trans + transform_func_rs_b = shared_4x16_to_mma_b_32x4_layout + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward(i: int, j: int, rep: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, local_id = inverse_mma_load_layout.map_indices([i, j, rep]) + return lane_id, local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2 + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.chunk + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], 
repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. 
+ """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) diff --git a/tilelang/original/tilelang/intrinsics/mma_sp_layout.py b/tilelang/original/tilelang/intrinsics/mma_sp_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..58034e7fdba90cf6bc8408db6e3ad8c10ce0661a --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mma_sp_layout.py @@ -0,0 +1,181 @@ +from tvm import DataType +from typing import Literal + +from tilelang.intrinsics.mma_layout import ( + mma_load_a_32x4_to_shared_16x8_layout, + mma_load_a_32x16_to_shared_16x32_layout, + mma_load_a_32x8_to_shared_16x16_layout, + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) + + +def shared_16x16_to_mma_sp_layout_sr_a(i, j): + return shared_16x8_to_mma_32x4_layout_sr_a(i, j) + + +def shared_16x16_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 4) + return thread_id, 4 * (i // 8) + (j // 4) + + +def shared_16x32_to_mma_sp_layout_sr_a(i, j): + return shared_16x16_to_mma_32x8_layout_sr_a(i, j) + + +def shared_16x32_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 8 * (i // 8) + (j // 8) * 2 + (j % 2) + + +def shared_16x64_to_mma_sp_layout_sr_a(i, j): + return shared_16x32_to_mma_32x16_layout_sr_a(i, j) + + +def shared_16x64_to_mma_sp_layout_sr_b(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 16 * (i // 8) + (j // 16) * 4 + j % 4 + + +def mma_sp_load_a_32x4_to_shared_16x16_layout(thread_id, local_id): + return mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id) + + +def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id): + return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id) + + +def mma_sp_load_a_32x16_to_shared_16x64_layout(thread_id, local_id): + return mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id) + + +def mma_sp_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): + col = 4 * (local_id % 4) + (thread_id % 4) + row = 8 * (local_id // 4) + (thread_id // 4) + return row, col + + +def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): + col = (thread_id % 4) * 2 + (local_id % 2) + ((local_id % 8) // 2) * 8 + row = (thread_id // 4) + 8 * (local_id // 8) + return row, col + + +def mma_sp_load_b_32x32_to_shared_16x64_layout(thread_id, local_id): + col = (thread_id % 4) * 4 + (local_id % 4) + 16 * ((local_id % 16) // 4) + row = (thread_id // 4) + 8 * (local_id // 16) + return row, col + + +def get_logical_id_32bit(thread_id: int) -> int: + return (thread_id // 4) * 2 + (thread_id % 4) % 2 + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_32bit(thread_id) + row = logical_id // 4 + local_id * 8 + col = logical_id % 4 + return row, col + + +def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, 
local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_32bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = logical_id % 2 + return row, col + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit + + +def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit + + +def get_logical_id_8bit(thread_id: int) -> int: + return thread_id + + +def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = (logical_id % 4) // 2 * 4 + local_id + return row, col + + +def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 2 + local_id * 8 + col = (logical_id % 4) // 2 * 2 + local_id + return row, col + + +def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: + # local_id is always 0 + logical_id = get_logical_id_8bit(thread_id) + row = logical_id // 4 + (logical_id % 2) * 8 + col = (logical_id % 4) // 2 + return row, col + + +def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): + row = (local_id // 4) * 8 + thread_id % 8 + col = (thread_id // 8) * 4 + local_id % 4 + return row, col + + +def ldmatrix_32x16_to_shared_32x16_layout(thread_id, local_id): + row = thread_id + col = local_id % 8 + 8 * (local_id // 8) + return row, col + + +def ldmatrix_trans_32x16_to_shared_16x32_layout(thread_id, local_id): + row = 8 * (local_id // 8) + thread_id % 8 + col = (thread_id // 8) * 8 + local_id % 8 + return row, col + + +def ldmatrix_trans_32x32_to_shared_shared_16x64_layout(thread_id, local_id): + row = (local_id // 16) * 8 + thread_id % 8 + col = (thread_id // 8) * 16 + local_id % 16 + return row, col + + +def get_ldmatrix_offset_b( + matrix: Literal["B"], + row_idx, + col_idx, + stride, + dtype: Literal["float16", "int8"] = "float16", + transposed: bool = False, +): + assert matrix == "B", "matrix should be B" + dtype_bits = DataType(dtype).bits + if dtype_bits == 32: + if transposed: + transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed for 32-bit dtype") + elif dtype_bits == 16: + transform_func = ldmatrix_32x16_to_shared_32x16_layout + transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout + if transposed: + new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif dtype_bits == 8: + if transposed: + transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed for 8-bit dtype") + else: + raise ValueError(f"Unsupported dtype {dtype}") diff --git 
a/tilelang/original/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/original/tilelang/intrinsics/mma_sp_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3e375b46b97547916e2aaf7efd56c3926b7af5b4 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mma_sp_macro_generator.py @@ -0,0 +1,831 @@ +from __future__ import annotations + +import tilelang.language as T +from typing import Literal, Callable +from tvm import DataType +from tvm.tir import PrimExpr, IndexMap, Buffer, Var +from tvm.runtime import convert +from .utils import ( + mma_store_index_map, + get_ldmatrix_offset, +) +from tilelang.utils import is_fragment + +from tilelang.intrinsics.mma_sp_layout import ( + shared_16x16_to_mma_sp_layout_sr_a, + shared_16x16_to_mma_sp_layout_sr_b, + shared_16x32_to_mma_sp_layout_sr_a, + shared_16x32_to_mma_sp_layout_sr_b, + shared_16x64_to_mma_sp_layout_sr_a, + shared_16x64_to_mma_sp_layout_sr_b, + mma_sp_load_a_32x4_to_shared_16x16_layout, + mma_sp_load_a_32x8_to_shared_16x32_layout, + mma_sp_load_a_32x16_to_shared_16x64_layout, + mma_sp_load_b_32x8_to_shared_16x16_layout, + mma_sp_load_b_32x16_to_shared_16x32_layout, + mma_sp_load_b_32x32_to_shared_16x64_layout, + metadata_8bit_load_32x4_to_shared_16x4_layout_32bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_32bit, + metadata_8bit_load_32x4_to_shared_16x4_layout_16bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_16bit, + metadata_8bit_load_32x4_to_shared_16x4_layout_8bit, + metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, + metadata_32bit_load_32x1_to_shared_16x2_layout_8bit, + get_ldmatrix_offset_b, +) + +lift = convert + + +class SparseTensorCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. + """ + + M_DIM = 16 + SPARSE_FACTOR = 2 # 1:2 for tfloat12, 2:4 for 16-bit and 8-bit datatypes + SPARSE_SELECTOR = 0 # always use lower threads to provide metadata + # use lowercase as n_dim can be dynamic + # the smallest instructions can be m16n8k16, so the n_dim can also be 8 + n_dim = 16 + WARP_SIZE = 32 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e5m2": "e5m2", + } + + E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor + "float": { + "int16": 8, + "uint16": 8, + }, + "float32": { + "int16": 8, + "uint16": 8, + }, + "float16": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "bfloat16": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "int8": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "uint8": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "float8_e4m3": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + "float8_e5m2": { + "int8": 8, + "uint8": 8, + "int16": 16, + "uint16": 16, + "int32": 32, + "uint32": 32, + }, + } + + E_REPLICATE_FACTOR = { # metadata replicate every 4 consecutive threads + "float32": 2, + "float16": 2, # 2 of 4 consecutive threads provides + "bfloat16": 2, + "int8": 1, # 4 of 4 consecutive threads provides + "uint8": 1, + "float8_e4m3": 1, + "float8_e5m2": 1, + } + + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = T.float16, + e_dtype: str = T.uint8, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + 
a_transposed: bool = False, + b_transposed: bool = False, + e_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + warp_k: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.e_dtype = e_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + self.e_transposed = e_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.warp_k = warp_k + self.e_factor = self.E_FACTOR_MAP[self.a_dtype][self.e_dtype] + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_micro_size(self.M_DIM, self.k_dim) + self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) + self._initialize_mma_sp_prefix(self.k_dim) + self._initialize_is_m_first(is_m_first) + + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + if self.warp_rows == 0 or self.warp_cols == 0: + raise ValueError( + f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" + ) + + def _initialize_k_dim(self, a_dtype=T.float16): + if isinstance(a_dtype, str): + a_dtype = DataType(a_dtype) + # NOTE: k_dim here represents the logical shape of the MMA operation. + # When referring to the physical data movement, it should be divided by sparse_factor. 
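+ # e.g. float16/bfloat16: 256 // 16 * 2 = 32 (m16n8k32); int8/fp8: 256 // 8 * 2 = 64 (m16n8k64);
+ # tf32 stored as float32: 256 // 32 * 2 = 16 (m16n8k16). See _initialize_mma_sp_prefix.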
+ self.k_dim = 256 // a_dtype.bits * self.SPARSE_FACTOR + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR + self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mma_sp_prefix(self, k_dim: int = 16): + if k_dim == 16: + # typically used for tfloat32 + self.mma_prefix = "m16n8k16" + elif k_dim == 32: + # typically used for float16/bfloat16 + self.mma_prefix = "m16n8k32" + elif k_dim == 64: + # typically used for int8/fp8 + self.mma_prefix = "m16n8k64" + else: + raise ValueError("Unsupported k_dim") + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + self.warp_rows = warp_row_tiles // m_dim + + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + # NOTE: k_dim here represents the logical shape of the MMA operation. 
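+ # Physically, the A operand only stores micro_size_k // SPARSE_FACTOR elements along K
+ # per tile (structured sparsity); cf. local_size_a in _initialize_local_size.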
+ self.micro_size_k = k_dim + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32) + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + warp_k = self.warp_k + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + # ldmatrix cannot be used for int8 + trans case. 
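+ # The ".b16" ldmatrix path can only transpose 16-bit elements, so a transposed
+ # non-16-bit A falls back to the explicit per-element copy below.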
+ ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(a_dtype).bits == 8: + mma_load_layout = mma_sp_load_a_32x16_to_shared_16x64_layout + elif DataType(a_dtype).bits == 16: + mma_load_layout = mma_sp_load_a_32x8_to_shared_16x32_layout + elif DataType(a_dtype).bits == 32: + mma_load_layout = mma_sp_load_a_32x4_to_shared_16x16_layout + else: + raise ValueError(f"Unsupported dtype: {a_dtype}") + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx, _, warp_m = self.extract_thread_binding(thread_binding) + trans = self.a_transposed + + for i in T.serial(warp_rows): + # Assign A_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR + A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] + + if ldmatrix_available: + T.ptx_ldmatrix( + a_dtype, + T.bool(trans), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + else: + for j in T.serial(local_size_a): + mi, mk = mma_load_layout(tx, j) + A_local_buf[i * local_size_a + j] = ( + A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk] + ) + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + + def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + warp_k = self.warp_k + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_e = self.local_size_e + a_dtype = self.a_dtype + e_dtype = self.e_dtype + trans = self.e_transposed + # ldmatrix cannot be used for int8 + trans case. 
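+ # For now the metadata (E) fragment is always gathered with plain per-thread copies
+ # through the metadata_* layout helpers; the CUTLASS iterator below is the reference.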
+ # include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h + ldmatrix_available = False # TODO: use ldmatrix when possible + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(e_dtype).bits == 8: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 8bit: {a_dtype}") + elif DataType(e_dtype).bits == 16: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}") + elif DataType(e_dtype).bits == 32: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_32bit_load_32x1_to_shared_16x2_layout_8bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 32bit: {a_dtype}") + else: + raise ValueError(f"Unsupported dtype: {e_dtype}") + + thread_binding = self.get_thread_binding() + + @T.macro + def _warp_ldmatrix_e( + E_local_buf, + E_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + for i in T.serial(warp_rows): + # Assign E_shared_buf_elem + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor + for j in T.serial(local_size_e): + mi, mk = mma_load_layout(tx, j) + E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk] + + return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + warp_k = self.warp_k + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + b_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + replicate_b = self.n_dim == 16 + # ldmatrix cannot be used for int8 + trans case. 
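+ # Mirror of the A-side check: a non-16-bit B that is not stored transposed would
+ # need a transposing ldmatrix, which ".b16" cannot do, so fall back to plain copies.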
+ ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(b_dtype).bits == 8: + mma_load_layout = mma_sp_load_b_32x32_to_shared_16x64_layout + elif DataType(b_dtype).bits == 16: + mma_load_layout = mma_sp_load_b_32x16_to_shared_16x32_layout + elif DataType(b_dtype).bits == 32: + mma_load_layout = mma_sp_load_b_32x8_to_shared_16x16_layout + else: + raise ValueError(f"Unsupported dtype: {b_dtype}") + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + stride = B_shared_buf.shape[-1] + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + trans = not b_transposed + + for i in T.serial(warp_cols): + # Assign B_shared_elem + wi, wk = ( + warp_n * warp_col_tiles + i * micro_size_y, + rk * warp_k + ki * micro_size_k, + ) + + if ldmatrix_available: + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] + + if replicate_b: + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + ) + + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b + lift(local_size_b) // 2, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed), + ) + else: + T.ptx_ldmatrix( + b_dtype, + T.bool(trans), + 4, + ".b16", + B_local_buf.data, + i * local_size_b, + T.address_of(B_shared_buf_elem), + get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + ) + + else: + # load 16x32 data from shared buffer to local buffer + # must be transposed. 
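+ # mma_load_layout maps (lane, local register) -> (row, col) within one B tile;
+ # the two shared-memory indices are swapped depending on b_transposed.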
+ for j in T.serial(local_size_b): + mi, mk = mma_load_layout(tx, j) + B_local_buf[i * local_size_b + j] = ( + B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi] + ) + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_e = self.local_size_e + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + replicate_b = self.n_dim == 16 + + a_is_fragment = is_fragment(A_local_buf) + e_is_fragment = is_fragment(E_local_buf) + b_is_fragment = is_fragment(B_local_buf) + assert not e_is_fragment, f"currently E_local_buf must be a local allocation, found {E_local_buf.scope()}" + a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 + e_local_stride: PrimExpr = k_inner * warp_rows * local_size_e if e_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + + @T.macro + def _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma_sp( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + E_local_buf.data, # metadata + e_local_stride + i * local_size_e, # metadata offset + self.SPARSE_SELECTOR, # sparse_selector + T.bool(False), # saturate + ) + if replicate_b: + T.ptx_mma_sp( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + a_local_stride + i * local_size_a, + B_local_buf.data, + b_local_stride + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + E_local_buf.data, # metadata + e_local_stride + i * local_size_e, # metadata offset + self.SPARSE_SELECTOR, # sparse_selector + T.bool(False), # saturate + ) + + return _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf) + + def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, n_dim = self.M_DIM, self.n_dim + C_buf_dims = len(C_buf.shape) + assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + + thread_binding = self.get_thread_binding() + + # STS + # MMA Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # equal to the warp_size + @T.macro + def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = 
local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + if C_buf_dims == 2: + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] + else: + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] + + @T.macro + def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] + + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + dtype = self.a_dtype if matrix_is_a else self.b_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x16_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_sp_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x32_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_sp_layout_sr_b + elif dtype_bits == 8: + transform_func_sr_a = shared_16x64_to_mma_sp_layout_sr_a + transform_func_sr_b = shared_16x64_to_mma_sp_layout_sr_b + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. 
+ # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] + if is_sr_axis_order + else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.warp_k + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. 
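+ 
+     Notes
+     -----
+     As in the dense emitters, this describes the operand (A or B) load layout;
+     accumulator fragments use `make_mma_store_layout`.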
+ """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y + mma_i, mma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) diff --git a/tilelang/original/tilelang/intrinsics/mmac_layout.py b/tilelang/original/tilelang/intrinsics/mmac_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..33b83b224a181c67b44a395fe438d33b120860e0 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mmac_layout.py @@ -0,0 +1,4 @@ +def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id): + i = thread_id % 16 + j = local_id + (thread_id // 16) * 4 + return i, j \ No newline at end of file diff --git a/tilelang/original/tilelang/intrinsics/mmac_macro_generator.py b/tilelang/original/tilelang/intrinsics/mmac_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3cf51823c7ed16c669f2a1c06fe9c9a0045062 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/mmac_macro_generator.py @@ -0,0 +1,710 @@ +from __future__ import annotations +from tilelang import tvm as tvm +import tilelang.language as T +from tvm import DataType +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad +from tvm.runtime import convert +from .utils import ( + mmac_store_index_map, +) +from tilelang.utils import is_fragment +from .mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + 
shared_4x16_to_local_64x1_layout_B, + shared_16x16_to_local_64x4_layout_A, + shared_16x16_to_local_64x4_layout_B, + shared_16x32_to_local_64x8_layout_A, + shared_16x32_to_local_64x8_layout_B, + shared_16x64_to_local_64x16_layout_A, + shared_16x64_to_local_64x16_layout_B, + thread_id_shared_access_64x1_to_16x4_layout_A, + thread_id_shared_access_64x1_to_4x16_layout_B, + thread_id_shared_access_64x4_to_16x16_layout_A, + thread_id_shared_access_64x4_to_16x16_layout_B, + thread_id_shared_access_64x8_to_16x32_layout_A, + thread_id_shared_access_64x8_to_16x32_layout_B, + thread_id_shared_access_64x16_to_16x64_layout_A, + thread_id_shared_access_64x16_to_16x64_layout_B, +) + +lift = convert + + +class MatrixCoreIntrinEmitter: + """ + To eliminate Python syntax within TIR Macro. + """ + + M_DIM = 16 + N_DIM = 16 + WARP_SIZE = 64 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "float8_e4m3": "e4m3", + "float8_e5m2": "e5m2", + "float8_e4m3fnuz": "e4m3fnuz", + } + + # k_pack represents the number of elements in a vectorized instruction + # Detail information can be found in the triton documentation + # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419 + k_pack = 1 + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + k_pack: int | None = None, + is_m_first: bool | None = False, + b_preshuffle: bool | None = False, + thread_var: Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) + self._initialize_mmac_prefix(self.k_dim) + self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) + self._initialize_k_pack(k_pack) + self._initialize_is_m_first(is_m_first) + self._initialize_b_preshuffle(b_preshuffle) + + self.warp_rows = warp_row_tiles // self.micro_size_x + self.warp_cols = warp_col_tiles // self.micro_size_y + self.reduce_k = reduce_k + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var + + def _initialize_k_dim(self, a_dtype=T.float16): + if isinstance(a_dtype, str): + if a_dtype in ["float8_e4m3fnuz", T.int8]: + self.k_dim = 32 + return + a_dtype = DataType(a_dtype) + + if a_dtype.bits == 32: + self.k_dim = 4 + elif a_dtype.bits in {16, 8}: + self.k_dim = 16 + else: + raise ValueError(f"Unsupported a_dtype = {a_dtype}") + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * 
n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mmac_prefix(self, k_dim=16): + in_dtype, out_dtype = self.a_dtype, self.accum_dtype + M_DIM, N_DIM = self.M_DIM, self.N_DIM + out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] + + in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "bfloat16": "bf16"}[in_dtype] + + self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" + + def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): + self.micro_size_x = m_dim + self.micro_size_y = n_dim + self.micro_size_k = k_dim + + def _initialize_k_pack(self, k_pack: int | None = None): + if k_pack is not None: + self.k_pack = k_pack + + def _initialize_is_m_first(self, is_m_first: bool | None = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False): + if b_preshuffle is not None: + self.b_preshuffle = b_preshuffle + + def get_ldmatrix_index_map(self, is_b=False): + k_dim = self.k_dim * self.k_pack + transposed = self.a_transposed if not is_b else self.b_transposed + if k_dim == 4: + index_map = shared_16x4_to_local_64x1_layout_A + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A + if is_b: + index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B + reverse_index_map = ( + thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + ) + elif k_dim == 16: + index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + ) + + if is_b: + index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + ) + elif k_dim == 32: + index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + ) + + if is_b: + index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + ) + elif k_dim == 64: + index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + ) + + if is_b: + index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + ) + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + return index_map, reverse_index_map + + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = 
self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func(mmac_store_index_map, index_dtype=T.int32) + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + k_pack = self.k_pack + is_transposed = self.a_transposed + thread_binding = self.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) + + # legalize shared buffer to region + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + + def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0): + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + k_pack = self.k_pack + is_transposed = self.b_transposed + thread_binding = self.get_thread_binding() + _, reverse_index_map = 
self.get_ldmatrix_index_map(is_b=True) + + # legalize shared buffer to region + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id)) + l, r = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * (k_pack * micro_size_k), + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id)) + l, r = ( + rk * chunk + ki * (k_pack * micro_size_k), + warp_n * warp_col_tiles + j * micro_size_y, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + + def mmac(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + k_pack = self.k_pack + mmac_suffix = self.mmac_suffix + a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype + compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}" + compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" + compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" + + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0 + + @T.macro + def _warp_mmac(A_local_buf, B_local_buf, C_local_buf): + for kp, i, j in T.grid(k_pack, warp_rows, warp_cols): + T.tvm_mmac( + mmac_suffix, + "row", + "row", + compute_a_dtype, + compute_b_dtype, + compute_out_dtype, + A_local_buf.data, + (a_local_stride + (j * k_pack + kp) * local_size_a) // local_size_a, + B_local_buf.data, + (b_local_stride + (i * k_pack + kp) * local_size_b) // local_size_b, + C_local_buf.data, + (i * warp_cols * local_size_out + j * local_size_out) // local_size_out, + dtype=compute_out_dtype, + ) + + return _warp_mmac(A_local_buf, B_local_buf, C_local_buf) + + def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + thread_binding = self.get_thread_binding() + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, N_DIM = self.M_DIM, self.N_DIM + C_buf_dims = len(C_buf.shape) + assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + + # STS + # MMAC Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # 
equal to the warp_size
+        @T.macro
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            for i, j in T.grid(warp_rows, warp_cols):
+                for local_id in T.vectorized(local_size_out):
+                    row, col = T.meta_var(mmac_store_index_map(tx, local_id))
+                    if C_buf_dims == 2:
+                        C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
+                            j * (warp_rows * local_size_out) + i * local_size_out + local_id
+                        ]
+                    else:
+                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row, col] = C_local_buf[
+                            j * warp_rows * local_size_out + i * local_size_out + local_id
+                        ]
+
+        @T.macro
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            for i, j in T.grid(warp_rows, warp_cols):
+                for local_id in T.vectorized(local_size_out):
+                    row, col = T.meta_var(mmac_store_index_map(tx, local_id))
+                    C_buf[
+                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col
+                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
+
+        return (
+            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+            if is_global
+            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+        )
+
+    def make_mmac_store_layout(self, local_buf: Buffer) -> T.Fragment:
+        """
+        Create a layout function for storing MMAC results into a fragment buffer.
+
+        Parameters
+        ----------
+        local_buf : tir.Buffer
+            The local buffer representing a fragment of a matrix.
+
+        Returns
+        -------
+        T.Fragment
+            A fragment object describing the thread and index layout for MMAC.
+        """
+        from tilelang.utils import is_fragment
+
+        shape = local_buf.shape
+        inverse_mmac_store_layout = self.get_store_index_map(inverse=True)
+        assert is_fragment(local_buf), "local_buf must be a fragment"
+
+        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
+        local_size_out = self.local_size_out
+        block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
+        warp_rows, warp_cols = self.warp_rows, self.warp_cols
+        warp_size = self.WARP_SIZE
+        is_m_first = self.is_m_first
+
+        def forward_thread(i: int, j: int) -> int:
+            """
+            Map fragment row `i` and column `j` to a thread index.
+            """
+            block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
+            mmac_i, mmac_j = i % micro_size_x, j % micro_size_y
+
+            lane_id, _ = inverse_mmac_store_layout.map_indices([mmac_i, mmac_j])
+
+            if is_m_first:
+                thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
+            else:
+                thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
+            return thread_id
+
+        def forward_index(i: int, j: int) -> int:
+            """
+            Map fragment row `i` and column `j` to a local index within a thread's registers.
+            """
+            warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
+            mmac_i, mmac_j = i % micro_size_x, j % micro_size_y
+
+            # Use the underlying MMAC hardware inverse mapping to obtain the local offset
+            _, local_id = inverse_mmac_store_layout.map_indices([mmac_i, mmac_j])
+
+            return warp_j * (warp_rows * local_size_out) + warp_i * local_size_out + local_id
+
+        return T.Fragment(
+            shape,
+            forward_thread_fn=forward_thread,
+            forward_index_fn=forward_index,
+        )
+
+    @staticmethod
+    def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
+        """
+        Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.
+ + - Buffer -> full-region BufferRegion covering entire shape + - BufferRegion -> returned as-is + - BufferLoad -> best-effort convert via get_buffer_region_from_load; + if scalar, fall back to 1-sized ranges at given indices + """ + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + # Fallback: scalar load -> 1-sized ranges at indices + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + +class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): + def __init__( + self, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + k_pack: int | None = None, + is_m_first: bool | None = False, + a_preshuffle: bool | None = False, + b_preshuffle: bool | None = False, + thread_var: Var | None = None, + ): + super().__init__( + a_dtype=a_dtype, + b_dtype=b_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + num_elems_per_byte=num_elems_per_byte, + k_pack=k_pack, + is_m_first=is_m_first, + thread_var=thread_var, + ) + self._initialize_preshuffle(a_preshuffle, b_preshuffle) + + def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): + if a_preshuffle is not None: + self.a_preshuffle = a_preshuffle + if b_preshuffle is not None: + self.b_preshuffle = b_preshuffle + + def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None): + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + k_pack = self.k_pack + is_transposed = self.a_transposed + current_frame = T.KernelLaunchFrame.Current() + thread_binding = current_frame.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) + is_global = pid_m is not None and pid_n is not None + + # no preshuffle, use the default implementation + if self.a_preshuffle is False: + return super().ldmatrix_a(A_local_buf, A_buf, ki, rk) + + def _warp_ldmatrix_a_global( + A_local_buf, + A_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + (pid_m * self.block_row_warps + warp_m) * warp_rows + i, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + (pid_m * self.block_row_warps + 
warp_m) * warp_rows + i, + rk * (chunk // micro_size_k) + ki, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col] + + @T.macro + def _warp_ldmatrix_a_shared( + A_local_buf, + A_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, _, warp_m = self.extract_thread_binding(thread_binding) + if is_transposed: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + warp_m * warp_rows + i, + ) + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] + else: + for i in T.serial(warp_rows): + for local_id in T.vectorized(k_pack * local_size_a): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki) + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] + + return ( + _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk) + ) + + def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + k_pack = self.k_pack + is_transposed = self.b_transposed + current_frame = T.KernelLaunchFrame.Current() + thread_binding = current_frame.get_thread_binding() + _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) + is_global = pid_m is not None and pid_n is not None + + if self.b_preshuffle is False: + return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n) + + @T.macro + def _warp_ldmatrix_b_global( + B_local_buf, + B_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + (pid_n * self.block_col_warps + warp_n) * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col] + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + (pid_n * self.block_col_warps + warp_n) * warp_cols + j, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col] + + @T.macro + def _warp_ldmatrix_b_shared( + B_local_buf, + B_shared_buf, + ki, + thread_binding, + rk=0, + ): + tx, warp_n, _ = self.extract_thread_binding(thread_binding) + if is_transposed: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id)) + l, r = ( + warp_n * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] + else: + for j in T.serial(warp_cols): + for local_id in T.vectorized(k_pack * local_size_b): + row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id)) + l, r = ( + rk * (chunk // micro_size_k) + ki, + warp_n * warp_cols + j, + ) + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] + + return ( + _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, 
thread_binding, rk) + if is_global + else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk) + ) diff --git a/tilelang/original/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/original/tilelang/intrinsics/tcgen05_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..923bb0e1067b9ce51a7d36163fb0ec15b3c97b27 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/tcgen05_macro_generator.py @@ -0,0 +1,446 @@ +from __future__ import annotations +from enum import IntEnum +import tilelang.language as T +from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter +from tvm import DataType +from tvm.tir import PrimExpr, Buffer, Var, BufferLoad, BufferRegion +from tilelang import tvm as tvm +from tilelang import _ffi_api +from tilelang.utils import is_tensor_memory +from tilelang.layout import ( + Layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, +) +from tvm.runtime import convert + +lift = convert + + +class SwizzleMode(IntEnum): + # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + NONE = 0 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + + def is_none(self) -> bool: + return self == SwizzleMode.NONE + + def is_swizzle_32b(self) -> bool: + return self == SwizzleMode.SWIZZLE_32B + + def is_swizzle_64b(self) -> bool: + return self == SwizzleMode.SWIZZLE_64B + + def is_swizzle_128b(self) -> bool: + return self == SwizzleMode.SWIZZLE_128B + + def swizzle_byte_size(self) -> int: + if self.is_swizzle_32b(): + return 32 + elif self.is_swizzle_64b(): + return 64 + elif self.is_swizzle_128b(): + return 128 + else: + return 1 + + def swizzle_atom_size(self) -> int: + if self.is_swizzle_32b(): + return 32 // 16 + elif self.is_swizzle_64b(): + return 64 // 16 + elif self.is_swizzle_128b(): + return 128 // 16 + else: + return 1 + + +# derive from MMAIntrinEmitter as some layouts are the same +class TensorCoreIntrinEmitter(MMAIntrinEmitter): + """ + To eliminate Python syntax within TIR Macro. 
+ """ + + # should be rewritten to support dynamic k_dim + tcgen05_prefix: str + + a_shared_layout: Layout = None + b_shared_layout: Layout = None + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool = False, + thread_var: Var | None = None, + ): + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) + + def _assign_a_shared_layout(self, layout: Layout): + self.a_shared_layout = layout + return self + + def _assign_b_shared_layout(self, layout: Layout): + self.b_shared_layout = layout + return self + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + # For tcgen05, warp_row_tiles is 8 as we can use .ws to support m32 + assert warp_row_tiles >= 8, f"warp_row_tiles must be greater than 8, got {warp_row_tiles}" + assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + # four warps per block + self.warp_rows = warp_row_tiles // 8 + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + self.micro_size_k = k_dim + + def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: + # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper + if layout is None or layout.is_equal(make_linear_layout(buffer)): + return SwizzleMode.NONE + elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_32B + elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_64B + elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + else: + raise ValueError(f"Unsupported swizzle mode: {layout}") + + def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False): + if is_tensor_memory(A_buf): + return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum) + + accum_dtype = self.accum_dtype + m_dim = self.block_row_warps * self.warp_row_tiles + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + + elems_in_bits = DataType(self.a_dtype).bits + elems_in_bytes = elems_in_bits // 8 + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = 
n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + accum_dtype_in_bits = DataType(accum_dtype).bits + + meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) + if len(meta) != 5: + raise ValueError( + f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " + f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) + atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) + + # by default, we utilize non-swizzle layout offset + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) + + if not a_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + # MN Major + # LBO represents the distance between two atoms along the M dimension + # SBO represents the distance between two atoms along the K dimension + a_m_axis_atoms = m_dim // a_swizzle_atom_elems + if a_m_axis_atoms <= 1: + a_leading_byte_offset = 0 + else: + a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size() + + if a_m_axis_atoms <= 1: + a_stride_byte_offset = 8 * elems_in_bytes * m_dim + else: + a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + # MN Major, K * N + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + # for example, if [n, k] where k is 128, we should split it into 2 atoms + # where max specially handles the case when n_dim is 8. 
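+        # Worked arithmetic (illustrative sketch only, assuming float16 operands so
+        # elems_in_bytes = 2 and micro_size_k = 16; the concrete swizzle modes below
+        # are hypothetical):
+        #   - A in SWIZZLE_128B mode: a_swizzle_atom_elems = 128 // 2 = 64, hence
+        #     ak_atom_size = max(64 // 16, 1) = 4 micro-k tiles per swizzle atom;
+        #   - B unswizzled with n_dim = 8: b_swizzle_atom_elems = 8, hence
+        #     bk_atom_size = max(8 // 16, 1) = 1, i.e. the max() clamp keeps it from
+        #     collapsing to 0.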
+ ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + + instr_desc = self.get_tcgen5_instr_desc( + atom_m, + atom_n, + atom_k, + a_is_k_major, + b_is_k_major, + scale_in_a, + scale_in_b, + ) + # Allocate an instruction descriptor wrapper and initialize it + a_dtype_abbrv = self.a_dtype_abbrv + mask_zero = T.Cast(T.int32, 0) + mask0 = mask1 = mask2 = mask3 = mask_zero + + # TCGEN05 only has one warp group + num_inst_m = self.block_row_warps * self.warp_row_tiles // atom_m + num_inst_n = self.block_col_warps * self.warp_col_tiles // atom_n + + # Helper to allow BufferRegion/BufferLoad as inputs + def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"): + if isinstance(buffer_or_load_or_region, Buffer): + return buffer_or_load_or_region.access_ptr(access_type) + elif isinstance(buffer_or_load_or_region, BufferLoad): + buffer_load = buffer_or_load_or_region + offset, stride = 0, 1 + buffer = buffer_load.buffer + for i, shape in enumerate(reversed(buffer.shape)): + indice = buffer_load.indices[len(buffer_load.indices) - i - 1] + if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)): + offset += indice * stride + elif isinstance(indice, tvm.tir.Ramp): + offset += indice.base * stride + else: + raise ValueError(f"Unsupported index type: {type(indice)}") + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + elif isinstance(buffer_or_load_or_region, BufferRegion): + buffer_region = buffer_or_load_or_region + buffer = buffer_region.buffer + offset, stride = 0, 1 + for i, shape in enumerate(reversed(buffer.shape)): + offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + else: + raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}") + + @T.macro + def _warp_mma(A_buf, B_buf, C_local_buf, mbar): + # Allocate SMEM descriptors for A and B + desc_a = T.alloc_tcgen05_smem_desc() + desc_b = T.alloc_tcgen05_smem_desc() + A_ptr = access_ptr_from(A_buf, "r") + B_ptr = access_ptr_from(B_buf, "r") + + T.initialize_tcgen05_descriptor( + desc_a, + A_ptr, + int(a_leading_byte_offset >> 4), + int(a_stride_byte_offset >> 4), + 0, + False, + int(a_swizzle_mode), + ) + T.initialize_tcgen05_descriptor( + desc_b, + B_ptr, + int(b_leading_byte_offset >> 4), + int(b_stride_byte_offset >> 4), + 0, + False, + int(b_swizzle_mode), + ) + + tmem_col_step = atom_n // (128 // atom_m) + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(0, (k_dim // micro_size_k)): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + A_elem_offset = ( + (ki % ak_atom_size) * micro_size_k + + i * atom_m * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + + B_elem_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + j * atom_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + + A_byte_offset = A_elem_offset * elems_in_bytes + B_byte_offset = B_elem_offset * elems_in_bytes + C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 # 32 bits per tmem bank + + T.ptx_tcgen05_mma_ss( + a_dtype_abbrv, + desc_a.data, + A_byte_offset, + desc_b.data, + B_byte_offset, + 
C_local_buf.data, + C_offset, + instr_desc, + scale_out, + mask0, + mask1, + mask2, + mask3, + enable_ws, + ) + T.tcgen05_mma_arrive(mbar) + + return _warp_mma(A_buf, B_buf, C_local_buf, mbar) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: + raise NotImplementedError + + def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout: + """ + Create the TCGEN5 tensor-memory layout used to store MMA accumulators. + + Parameters + ---------- + tmem_buf : tir.Buffer + The local buffer representing tensormemory of a mma's output + + Returns + ------- + Layout + Layout object describing how logical (i, j) coordinates map to the + swizzled tensor-memory offsets required by TCGEN5MMA. + + Raises + ------ + AssertionError + If `tmem_buf` is not detected to be a tensor-memory buffer. + """ + assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)" + if len(tmem_buf.shape) != 2: + raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}") + + m = int(tmem_buf.shape[0]) + n = int(tmem_buf.shape[1]) + k = int(self.chunk) + + meta = self.get_tcgen5_mma_meta(m, n, k) + if len(meta) != 5: + raise ValueError( + f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) + atom_m, atom_n, _, _, _ = (int(x) for x in meta) + + if m % atom_m != 0 or n % atom_n != 0: + raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})") + + def forward(i: PrimExpr, j: PrimExpr): + atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m) + ai = i % atom_m + aj = j % atom_n + + if atom_m == 128: + # Layout D + return [ + ai, + aj + atom_idx * atom_n, + ] + if atom_m == 64: + # Layout E (.ws variant) + half_atom_n = atom_n // 2 + return [ + (ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64, + (aj % half_atom_n) + atom_idx * half_atom_n, + ] + if atom_m == 32: + # Layout G + quarter_atom_n = atom_n // 4 + return [ + ai % 32 + (aj // quarter_atom_n) * 32, + (aj % quarter_atom_n) + atom_idx * quarter_atom_n, + ] + + raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}") + + return Layout([m, n], forward) + + def get_tcgen5_mma_meta(self, m: int, n: int, k: int): + return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) + + def get_tcgen5_instr_desc( + self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int + ) -> PrimExpr: + desc = _ffi_api.get_tcgen5_instr_desc( + atom_m, + atom_n, + atom_k, + DataType(self.a_dtype), + DataType(self.accum_dtype), + a_is_k_major, + b_is_k_major, + scale_in_a, + scale_in_b, + ) + return lift(desc) diff --git a/tilelang/original/tilelang/intrinsics/utils.py b/tilelang/original/tilelang/intrinsics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..475087a723f0ba2b8dfcbb4d1d4ce0040cd66133 --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/utils.py @@ -0,0 +1,117 @@ +from tvm import DataType +from typing import Literal +from .mma_layout import ( + ldmatrix_32x4_to_shared_16x8_layout_a, + ldmatrix_32x4_to_shared_16x8_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, + mma_store_32x8_to_shared_16x16_layout, + mma_store_32x2_to_shared_8x8_layout_fp64, +) +from .mfma_layout import 
thread_id_shared_access_64x4_to_16x16_layout_C_n_m + +from .mma_layout import get_swizzle_layout # noqa: F401 +from .mma_layout import make_mma_swizzle_layout # noqa: F401 +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 + + +# the original implementation and insight is from the following code snippet +# 3rdparty/tvm/python/tvm/tir/tensor_intrin/cuda.py#get_ldmatrix_intrin +def get_ldmatrix_offset( + matrix: Literal["A", "B"], + row_idx, + col_idx, + stride, + dtype: Literal["float16", "int8"] = "float16", + transposed: bool = False, +): + assert matrix in ["A", "B"], "matrix should be either A or B" + dtype_bits = DataType(dtype).bits + if dtype_bits == 32: + if matrix == "B" and transposed: + transform_func = ldmatrix_32x4_to_shared_16x8_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x4_to_shared_16x8_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") + elif dtype_bits == 16: + transform_func = ldmatrix_32x8_to_shared_16x16_layout + transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + if transposed: + new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif dtype_bits == 8: + if matrix == "B" and transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") + else: + raise ValueError(f"Unsupported dtype {dtype}") + + +def shared_16x16_to_mma_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) + + +def shared_16x32_to_mma_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 + + +def shared_32x16_to_mma_32x16_layout(i, j): + thread_id = (i % 16) // 4 + 4 * (j % 8) + return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4 + + +def mma_store_index_map(thread_id, local_id): + return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id) + + +def mma_store_index_map_fp64(thread_id, local_id): + return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id) + + +def mfma_store_index_map(thread_id, local_id): + return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) + +def mmac_store_index_map(thread_id, local_id): + return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) + + +def get_mma_micro_size(dtype: Literal["float16", "int8"]): + # TODO(lei): FP8 related precision support. + # Basic Tensor Core Matrix Multiply operation Unit + """ + Return the MMA (Tensor Core) micro-tile dimensions for a given data type. + + This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations. 
+ - x: tile width in the output/result dimension + - y: tile height in the output/result dimension + - k: tile depth in the reduction/K dimension + + Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16. + + Returns: + tuple[int, int, int]: (micro_size_x, micro_size_y, micro_size_k) + """ + micro_size_x = micro_size_y = 16 + micro_size_k = 16 + if dtype in {"float8_e4m3", "float8_e5m2", "int8"}: + micro_size_k = 32 + return micro_size_x, micro_size_y, micro_size_k diff --git a/tilelang/original/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/original/tilelang/intrinsics/wgmma_macro_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..864420c77174c34b5b75165bbe1398dcc96e4dac --- /dev/null +++ b/tilelang/original/tilelang/intrinsics/wgmma_macro_generator.py @@ -0,0 +1,612 @@ +from __future__ import annotations +import tilelang.language as T +from enum import IntEnum +from typing import Callable +from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter +from tvm import DataType +from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferRegion +from tilelang.utils import is_fragment, retrive_ptr_from_buffer_region, is_full_region +from math import gcd +from tilelang.layout import ( + Layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, +) +from tvm.runtime import convert +from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) + +lift = convert + + +class SwizzleMode(IntEnum): + # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + NONE = 0 + SWIZZLE_128B = 1 + SWIZZLE_64B = 2 + SWIZZLE_32B = 3 + + def is_none(self) -> bool: + return self == SwizzleMode.NONE + + def is_swizzle_32b(self) -> bool: + return self == SwizzleMode.SWIZZLE_32B + + def is_swizzle_64b(self) -> bool: + return self == SwizzleMode.SWIZZLE_64B + + def is_swizzle_128b(self) -> bool: + return self == SwizzleMode.SWIZZLE_128B + + def swizzle_byte_size(self) -> int: + if self.is_swizzle_32b(): + return 32 + elif self.is_swizzle_64b(): + return 64 + elif self.is_swizzle_128b(): + return 128 + else: + return 1 + + def swizzle_atom_size(self) -> int: + if self.is_swizzle_32b(): + return 32 // 16 + elif self.is_swizzle_64b(): + return 64 // 16 + elif self.is_swizzle_128b(): + return 128 // 16 + else: + return 1 + + +# derive from MMAIntrinEmitter as some layouts are the same +class TensorCoreIntrinEmitter(MMAIntrinEmitter): + """ + To eliminate Python syntax within TIR Macro. 
+ """ + + # should be rewritten to support dynamic k_dim + wgmma_prefix: str + + # wgmma instruction M dimension + wgmma_inst_m: int + # wgmma instruction N dimension + wgmma_inst_n: int + + a_shared_layout: Layout = None + b_shared_layout: Layout = None + + def __init__( + self, + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool | None = False, + thread_var: Var | None = None, + ): + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) + self._initialize_wgmma_prefix(self.n_dim) + + def _assign_a_shared_layout(self, layout: Layout): + self.a_shared_layout = layout + return self + + def _assign_b_shared_layout(self, layout: Layout): + self.b_shared_layout = layout + return self + + def _initialize_wgmma_prefix(self, n_dim: int = 16): + inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) + assert inst_n % 8 == 0, ( + f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) + # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 + assert 8 <= inst_n <= 256, ( + f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) + # 256 bits per instruction + inst_k = 256 // DataType(self.a_dtype).bits + self.wgmma_inst_m = inst_m + self.wgmma_inst_n = inst_n + self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}" + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + # four warps per block + self.warp_rows = warp_row_tiles // m_dim + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + self.micro_size_k = k_dim + + def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: + # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper + if layout is None or layout.is_equal(make_linear_layout(buffer)): + return SwizzleMode.NONE + elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_32B + elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_64B + elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + else: + raise ValueError(f"Unsupported swizzle mode: {layout}") + + def wgmma( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): + if is_fragment(A_region): + return 
self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait) + + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + + a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + + elems_in_bits = DataType(self.a_dtype).bits + elems_in_bytes = elems_in_bits // 8 + + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + accum_bits = DataType(accum_dtype).bits + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 + + # by default, we utilize non-swizzle layout offset + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) + + if not a_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + # MN Major + # LBO represents the distance between two atoms along the M dimension + # SBO represents the distance between two atoms along the K dimension + a_m_axis_atoms = m_dim // a_swizzle_atom_elems + if a_m_axis_atoms <= 1: + a_leading_byte_offset = 0 + else: + a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + + if a_m_axis_atoms <= 1: + a_stride_byte_offset = 8 * elems_in_bytes * m_dim + else: + a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + # MN Major, K * N + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + # for example, if [n, k] where k is 128, we should split it into 2 atoms + # where max specially handles the 
case when n_dim is 8. + ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n + num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m + num_inst_n = self.warp_col_tiles // wgmma_inst_n + + thread_binding = self.get_thread_binding() + + A_ptr = retrive_ptr_from_buffer_region(A_region) + B_ptr = retrive_ptr_from_buffer_region(B_region) + assert is_full_region(C_region), "Fragment output C must be a full region" + + C_buf = C_region.buffer + + @T.macro + def _warp_mma(A_ptr, B_ptr, C_buf): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + + desc_a = T.alloc_wgmma_desc() + desc_b = T.alloc_wgmma_desc() + T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + T.warpgroup_arrive() + + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(k_dim // micro_size_k): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + warp_i = (warp_m // 4) * num_inst_m + i + warp_j = warp_n * num_inst_n + j + A_offset = ( + (ki % ak_atom_size) * micro_size_k + + warp_i * 64 * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit + T.ptx_wgmma_ss( + accum_dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + desc_a.data, + (A_offset * elems_in_bytes) >> 4, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + T.warpgroup_commit_batch() + if wg_wait >= 0: + T.warpgroup_wait(wg_wait) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + + return _warp_mma(A_ptr, B_ptr, C_buf) + + def wgmma_rs( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): + local_size_a = self.local_size_a + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_rows, warp_cols = self.warp_rows, self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + elems_in_bytes = DataType(self.a_dtype).bits // 8 + a_bits = DataType(self.a_dtype).bits + accum_bits = DataType(accum_dtype).bits + a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32 + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * 
accum_bits + 31) // 32 + b_is_k_major = self.b_transposed + + b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + # MN Major + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n + num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m + num_inst_n = self.warp_col_tiles // wgmma_inst_n + + thread_binding = self.get_thread_binding() + + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(C_region), "Fragment output C must be a full region" + A_buf = A_region.buffer + B_ptr = retrive_ptr_from_buffer_region(B_region) + C_buf = C_region.buffer + + @T.macro + def _warp_mma(A_buf, B_ptr, C_buf): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + + desc_b = T.alloc_wgmma_desc() + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + T.warpgroup_arrive() + + for j in T.unroll(0, num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(0, (k_dim // micro_size_k)): + warp_j = warp_n * num_inst_n + j + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + + A_offset = ki * warp_rows * local_size_a + i * local_size_a + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit + T.ptx_wgmma_rs( + accum_dtype, + wgmma_prefix, + self.b_transposed, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf.data, + A_offset, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + T.warpgroup_commit_batch() + if wg_wait >= 0: + T.warpgroup_wait(wg_wait) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + + return _warp_mma(A_buf, B_ptr, C_buf) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: + """ + Create 
a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + assert matrix in ["A"], "matrix should be A for WGMMA" + dtype = self.a_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(not transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows = self.warp_rows + chunk = self.chunk + + warp_s = warp_rows + warp_r = chunk // micro_size_r + block_s = block_row_warps + replicate = block_col_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) + else: + # rs condition, transposed_a matrix + warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. 
+ This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + lane_id, _ = inverse_mma_store_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. + """ + _, local_id = inverse_mma_store_layout.map_indices([i, j]) + return local_id + + # reproduce src/layout/gemm_layouts.cc::makeGemmFragmentCHopper + base_fragment = T.Fragment( + [micro_size_x, micro_size_y], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + warp_n_layout = base_fragment.repeat([1, warp_cols], False, False) + block_layout = warp_n_layout.repeat([block_row_warps, block_col_warps], True, False) + warp_m_layout = block_layout.repeat([warp_rows, 1], False, False) + return warp_m_layout diff --git a/tilelang/original/tilelang/ir.py b/tilelang/original/tilelang/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a7de5ebb22fd43d7f8e23813966e72f9d8ca2a --- /dev/null +++ b/tilelang/original/tilelang/ir.py @@ -0,0 +1,76 @@ +from tilelang import tvm as tvm +from tvm.ir.base import Node +from tvm.runtime import Scriptable +import tvm_ffi +from tvm.target import Target +from tilelang import _ffi_api + + +@tvm_ffi.register_object("tl.Fill") +class Fill(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.AtomicAdd") +class AtomicAdd(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.Copy") +class Copy(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.Conv2DIm2Col") +class Conv2DIm2ColOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.GemmWarpPolicy") +class GemmWarpPolicy(Node, Scriptable): + policy_type: int + m_warp: int + n_warp: int + + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool): + _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma) + return self.m_warp, self.n_warp + + +@tvm_ffi.register_object("tl.GemmSPWarpPolicy") +class GemmSPWarpPolicy(Node, Scriptable): + policy_type: int + m_warp: int + n_warp: int + + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int): + _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits) + return self.m_warp, self.n_warp + + +@tvm_ffi.register_object("tl.Gemm") +class Gemm(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.GemmSP") +class GemmSP(Node, Scriptable): ... 
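+
+# Illustrative note (not part of the FFI surface): `GemmWarpPolicy.compute_warp_partition`
+# fills `m_warp` / `n_warp` on the C++ side and returns the chosen split. A hypothetical
+# call for a 256-thread block might look like:
+#
+#     m_warp, n_warp = policy.compute_warp_partition(
+#         M=128, N=128, block_size=256, target=Target("cuda"), is_wgmma=False)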
+ + +@tvm_ffi.register_object("tl.FinalizeReducerOp") +class FinalizeReducerOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.ParallelOp") +class ParallelOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.ReduceOp") +class ReduceOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.CumSumOp") +class CumSumOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.RegionOp") +class RegionOp(Node, Scriptable): ... + + +@tvm_ffi.register_object("tl.ReduceType") +class ReduceType(Node, Scriptable): ... diff --git a/tilelang/original/tilelang/jit/__init__.py b/tilelang/original/tilelang/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eac206f7245a62929f418daba245329d07dd127e --- /dev/null +++ b/tilelang/original/tilelang/jit/__init__.py @@ -0,0 +1,569 @@ +""" +This module provides an auto-tuning infrastructure for TileLang (tl) programs. +It includes functionality to JIT-compile TileLang programs into a runnable +kernel adapter using TVM. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import inspect +from typing import ( + Any, + Callable, + Generic, + TypeVar, + overload, + Literal, +) +from collections.abc import Iterable + +# Python 3.9 compatibility for ParamSpec +try: + from typing import ParamSpec +except ImportError: # Python < 3.10 + from typing_extensions import ParamSpec +from tilelang import tvm as tvm +from tilelang.language.v2 import PrimFunc, PrimFuncCreater, prim_func +from tilelang.language.v2.annot import Annot +from tvm.target import Target + +from tilelang.jit.kernel import JITKernel +from tilelang.utils.target import determine_target +from tilelang.cache import cached +from os import path, makedirs +from logging import getLogger +from tilelang.jit.param import Kernel +import concurrent.futures + +from tqdm.auto import tqdm + +logger = getLogger(__name__) + +_P = ParamSpec("_P") +_KP = ParamSpec("_KP") +_T = TypeVar("_T") +_Ret = TypeVar("_Ret") + + +def compile( + func: PrimFunc[_KP, _T] = None, + out_idx: list[int] | int | None = None, + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", + target: str | Target = "auto", + target_host: str | Target | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel[_KP, _T]: + """ + Compile the given TileLang PrimFunc with TVM and build a JITKernel. + Parameters + ---------- + func : tvm.tir.PrimFunc, optional + The TileLang TIR function to compile and wrap. + out_idx : Union[List[int], int], optional + Index(es) of the output tensors to return (default: None). + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional + Execution backend to use for kernel execution. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). + target : Union[str, Target], optional + Compilation target, either as a string or a TVM Target object (default: "auto"). + target_host : Union[str, Target], optional + Target host for cross-compilation (default: None). + verbose : bool, optional + Whether to enable verbose output (default: False). + pass_configs : dict, optional + Additional keyword arguments to pass to the Compiler PassContext. + Refer to `tilelang.transform.PassConfigKey` for supported options. 
+ """ + + assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" + + if hasattr(func, "out_idx_override"): + if func.out_idx_override is not None and out_idx is not None: + raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") + out_idx = func.out_idx_override or out_idx + + # This path is not a performance critical path, so we can afford to convert the target. + target = Target(determine_target(target)) + + # Resolve execution backend (handles aliases, auto, validation per target) + requested_backend = execution_backend + from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + + execution_backend = resolve_execution_backend(requested_backend, target) + if verbose: + allowed_now = allowed_backends_for_target(target, include_unavailable=False) + logger.info( + "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", + execution_backend, + requested_backend, + target.kind.name, + ", ".join(sorted(allowed_now)), + ) + + return cached( + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + target=target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def par_compile( + funcs: Iterable[PrimFunc[_KP, _T]], + out_idx: list[int] | int | None = None, + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", + target: str | Target = "auto", + target_host: str | Target | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, + num_workers: int = None, + ignore_error: bool = False, +) -> list[JITKernel[_KP, _T]]: + """ + Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. + Parameters + ---------- + funcs : Iterable[tvm.tir.PrimFunc] + The TileLang TIR functions to compile and wrap. + out_idx : Union[List[int], int], optional + Index(es) of the output tensors to return (default: None). + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional + Execution backend to use for kernel execution. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). + target : Union[str, Target], optional + Compilation target, either as a string or a TVM Target object (default: "auto"). + target_host : Union[str, Target], optional + Target host for cross-compilation (default: None). + verbose : bool, optional + Whether to enable verbose output (default: False). + pass_configs : dict, optional + Additional keyword arguments to pass to the Compiler PassContext. + Refer to `tilelang.transform.PassConfigKey` for supported options. + """ + with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor: + futures = [] + future_map = {} + for i, func in enumerate(funcs): + future = executor.submit( + compile, + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + target=target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + future_map[future] = i + futures.append(future) + results = [... 
for _ in futures] + for future in tqdm( + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Parallel Compiling", + ): + idx = future_map[future] + if ignore_error: + try: + results[idx] = future.result() + except Exception as e: + logger.warning(f"Error compiling function at index {idx}: {e}") + results[idx] = None + else: + results[idx] = future.result() + return results + return results + + +@dataclass +class JITImpl(Generic[_P, _KP, _T, _Ret]): + """ + Detailed Just-In-Time wrapper for TileLang programs. + + This dataclass encapsulates the configuration and runtime helpers used by the + top-level `jit` and `jit2` decorators. It represents a configured JIT + "factory" that can (a) elaborate TileLang/PrimFunc creators into concrete + TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via + the TVM bridge, (c) cache compiled kernels keyed by call-site arguments + (and optional tuning parameters), and (d) provide parallel compilation + helpers for batch autotuning workflows. + + Attributes + ---------- + out_idx : list[int] | int | None + Which output tensor(s) of the compiled kernel should be returned to the + caller. Accepts a single index, a list of indices, or None to return all. + execution_backend : Literal["dlpack", "ctypes", "cython"] + Backend used for exchanging arguments and executing the generated kernel. + target : str | tvm.target.Target + TVM compilation target (e.g. "cuda", "llvm", or "auto"). + target_host : str | tvm.target.Target | None + Host target used for cross-compilation, or None to infer/default. + verbose : bool + Enable verbose messages during compilation/build. + pass_configs : dict[str, Any] | None + Extra TVM pass configuration options forwarded to the compiler's + PassContext. + debug_root_path : str | None + If provided, compiled kernel source and the elaborated Python program + are written to this directory to ease debugging and inspection. + compile_flags : list[str] | str | None + Additional flags passed to the compiler. A single string will be converted + to a single-element list. + func_source : str + Original Python source string from which the PrimFunc or creator was + derived. Used for diagnostics and debug dumps. + signature : inspect.Signature + Function signature of the original Python function (useful for tooling). + v2 : bool + Indicates whether the object wraps a "v2" PrimFunc creator (True) or a + plain callable / PrimFunc (False). v2-mode enables argument conversion + hooks and a distinct cache keying strategy. + func : Callable | PrimFunc | PrimFuncCreater + The underlying object: either a user function that returns a PrimFunc + (creator), a PrimFuncCreater, or an already-constructed PrimFunc. + For presentation/readability the function is stored last in the dataclass. + + Behavioral summary + ------------------ + - get_tir(*args, **kwargs) + Converts provided call-site arguments into a concrete PrimFunc. If the + wrapped object is a PrimFuncCreater or a user callable, it is invoked + with the given arguments. If the wrapped object is already a PrimFunc, + it is returned as-is. + + - compile(...) + A convenience wrapper that elaborates and immediately compiles a single + PrimFunc into a JITKernel using the module-level `compile` function. + When `debug_root_path` is set, the compiled C kernel and the source + Python program are saved for inspection. + + - par_compile(configs, ...) + Accepts an iterable of configs (either dicts mapping keyword args or + tuples mapping to positional args). 
Each config is elaborated to a + PrimFunc and the resulting set is compiled in parallel via the + module-level `par_compile` helper. Returns a list of JITKernel objects + in the same order as the provided configs. + """ + + out_idx: list[int] | int | None + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] + target: str | Target + target_host: str | Target + verbose: bool + pass_configs: dict[str, Any] | None + debug_root_path: str | None + compile_flags: list[str] | str | None + func_source: str + signature: inspect.Signature + lazy_jit: bool + # place func at the last element for better __repr__ + func: Callable[_P, _T] | PrimFunc[_KP, _T] + + @property + def annot(self) -> dict[str, Annot]: + assert self.lazy_jit, "annot is only support in @tilelang.jit2" + return self.func.func_annot.annots + + def __post_init__(self): + if self.debug_root_path is not None and not path.isabs(self.debug_root_path): + try: + base_path = path.dirname(path.dirname(path.dirname(__file__))) + self.debug_root_path = path.join(base_path, self.debug_root_path) + except NameError: + self.debug_root_path = path.abspath(self.debug_root_path) + self._kernel_cache: dict[tuple, Kernel] = {} + self._tuner_cache: dict[tuple, Kernel] = {} + + def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]: + """ + Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object. + """ + if isinstance(self.func, PrimFuncCreater): + tir = self.func(*args, **kwargs) + elif isinstance(self.func, PrimFunc): + tir = self.func + elif callable(self.func): + tir = self.func(*args, **kwargs) + else: + raise ValueError(f"Invalid function type: {type(self.func)}") + assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}" + return tir + + def par_compile( + self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False + ) -> list[JITKernel[_KP, _T]]: + """ + Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. + Parameters + ---------- + configs : Iterable[Union[dict[str, Any], tuple[Any, ...]]] + The configurations to elaborate and compile. Each config can be either + a dictionary mapping keyword arguments to values, or a tuple of positional + arguments. + num_workers : int, optional + Number of parallel workers to use for compilation. Defaults to None, + which lets the system decide. + ignore_error : bool, optional + If True, compilation errors for individual configs will be logged + as warnings and the corresponding result will be None. If False, + any compilation error will raise an exception. Defaults to False. + Returns + ------- + List[JITKernel] + A list of compiled JITKernel objects corresponding to the provided configs. 
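+
+        Examples
+        --------
+        A minimal sketch, assuming ``kernel_factory`` was produced by ``@tilelang.jit``
+        and its wrapped function takes ``block_M``/``block_N`` keyword arguments
+        (names are illustrative)::
+
+            kernels = kernel_factory.par_compile(
+                [{"block_M": 64, "block_N": 64}, {"block_M": 128, "block_N": 64}],
+                num_workers=2,
+            )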
+ """ + configs = list(configs) + funcs = [] + for cfg in tqdm(configs, desc="Elaborating"): + if isinstance(cfg, tuple): + funcs.append(self.get_tir(*cfg)) + elif isinstance(cfg, dict): + funcs.append(self.get_tir(**cfg)) + else: + raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.") + return par_compile( + funcs, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + num_workers=num_workers, + ignore_error=ignore_error, + ) + + def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: + func = self.get_tir(*args, **kwargs) + kernel_result = compile( + func, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) + + if self.debug_root_path: + if isinstance(self.func, PrimFunc): + func_name = self.func.attrs["global_symbol"] + else: + func_name = getattr(self.func, "__name__", "jit_kernel") + kernel_file = f"tilelang_jit_kernel_{func_name}.c" + program_file = f"tilelang_jit_program_{func_name}.py" + makedirs(self.debug_root_path, exist_ok=True) + with open(path.join(self.debug_root_path, kernel_file), "w") as f: + print(kernel_result.get_kernel_source(), file=f) + with open(path.join(self.debug_root_path, program_file), "w") as f: + print(func.script(), file=f) + + return kernel_result + + def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs): + if isinstance(self.func, PrimFuncCreater): + tune_params = kwargs.pop("__tune_params", {}) + return self.func.func_annot.parse_key(*args, **kwargs, **tune_params) + else: + tune_params = kwargs.pop("__tune_params", {}) + key_args_tuple = args + key_kwargs_tuple = tuple(sorted(kwargs.items())) + tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) + key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) + return key + + def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs): + if isinstance(self.func, PrimFuncCreater): + tune_params = kwargs.pop("__tune_params", {}) + return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params) + else: + raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.") + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: + # Separate out the tuning parameters from the user's kwargs + # Whether to return the compile arguments (out_idx, target, target_host, etc.) 
for autotuner cache + return_compile_arguments = kwargs.pop("__return_compile_arguments", False) + if return_compile_arguments: + logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.") + compile_args = { + "out_idx": self.out_idx, + "execution_backend": self.execution_backend, + "target": self.target, + "target_host": self.target_host, + "verbose": self.verbose, + "pass_configs": self.pass_configs, + "compile_flags": self.compile_flags, + } + return compile_args + + key = self.parse_cache_key(*args, **kwargs) + + tune_params = kwargs.pop("__tune_params", {}) + + kernel = self._kernel_cache.get(key, None) + if kernel is None: + kernel = self.compile(*args, **kwargs, **tune_params) + self._kernel_cache[key] = kernel + + if self.lazy_jit: + args = self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params) + return kernel(*args) + else: + return kernel + + +ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] + + +@overload +def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ... + + +@overload +def jit( + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ... + + +def jit( # This is the new public interface + func: Callable[_P, _T] | PrimFunc | None = None, + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +): + """ + Just-In-Time (JIT) compiler decorator for TileLang functions. + + This decorator can be used without arguments (e.g., `@tilelang.jit`): + Applies JIT compilation with default settings. + + Parameters + ---------- + func_or_out_idx : Any, optional + If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. + If using `@tilelang.jit` directly on a function, this argument is implicitly + the function to be decorated (and `out_idx` will be `None`). + target : Union[str, Target], optional + Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". + target_host : Union[str, Target], optional + Target host for cross-compilation. Defaults to None. + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional + Backend for kernel execution and argument passing. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). + verbose : bool, optional + Enables verbose logging during compilation. Defaults to False. + pass_configs : Optional[Dict[str, Any]], optional + Configurations for TVM's pass context. Defaults to None. + debug_root_path : Optional[str], optional + Directory to save compiled kernel source for debugging. Defaults to None. + + Returns + ------- + Callable + Either a JIT-compiled wrapper around the input function, or a configured decorator + instance that can then be applied to a function. 
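+
+    Examples
+    --------
+    A minimal sketch (shapes and tile sizes are illustrative)::
+
+        @tilelang.jit(out_idx=-1)
+        def matmul(M, N, K, block_M=128, block_N=128, block_K=32):
+            ...  # build and return a TileLang PrimFunc here
+
+        kernel = matmul(1024, 1024, 1024)  # elaborates and compiles once, cached per argument key
+        c = kernel(a, b)                   # runs the kernel; returns the ``out_idx`` tensor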
+ """ + + def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: + if isinstance(func, (PrimFunc, PrimFuncCreater)): + orig_func = func.orig_func + else: + orig_func = func + return JITImpl( + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + target=target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + debug_root_path=debug_root_path, + compile_flags=compile_flags, + func_source=inspect.getsource(orig_func), + signature=inspect.signature(orig_func), + lazy_jit=False, + ) + + if func is not None: + return decorator(func) + else: + return decorator + + +@overload +def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... + + +@overload +def lazy_jit( + *, + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... + + +def lazy_jit( + func: Callable[_P, _T] | PrimFunc | None = None, + *, # Indicates subsequent arguments are keyword-only + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +): + compile_args = dict( + out_idx=None, + execution_backend=execution_backend, + target=target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + debug_root_path=debug_root_path, + compile_flags=compile_flags, + ) + + def decorator(func: Callable[_P, _T]): + pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True) + # if isinstance(pf, PrimFunc): + # compile_args.pop('debug_root_path', None) + # return compile(pf, **compile_args) + # else: + return JITImpl( + func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True + ) + + return decorator(func) if func is not None else decorator diff --git a/tilelang/original/tilelang/jit/adapter/__init__.py b/tilelang/original/tilelang/jit/adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f511608fcd71334907d5cce7e1425109914dce70 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/__init__.py @@ -0,0 +1,7 @@ +from .base import BaseKernelAdapter # noqa: F401 +from .tvm_ffi import TVMFFIKernelAdapter # noqa: F401 +from .ctypes import CtypesKernelAdapter # noqa: F401 +from .cython import CythonKernelAdapter # noqa: F401 +from .nvrtc import NVRTCKernelAdapter # noqa: F401 +from .torch import MetalKernelAdapter # noqa: F401 +from .cutedsl import CuTeDSLKernelAdapter # noqa: F401 diff --git a/tilelang/original/tilelang/jit/adapter/base.py b/tilelang/original/tilelang/jit/adapter/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3669f9e35c6f0d667d6389adf23a1caf0109b865 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/base.py @@ -0,0 +1,96 @@ +"""The profiler and convert to torch utils""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable +from tilelang.engine.param import KernelParam +import torch + + +class BaseKernelAdapter(ABC): + func: Callable | None = None + + def __init__(self, mod, params: list[KernelParam], result_idx: 
list[int]) -> None: + self.mod = mod + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self._post_init() + + def _legalize_result_idx(self, result_idx: list[int] | None) -> list[int]: + params = self.params + # result_idx is a list of indices of the output tensors + if result_idx is None: + result_idx = [] + elif isinstance(result_idx, int): + if result_idx > len(params) or result_idx < -len(params): + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") + if result_idx < 0: + result_idx = len(params) + result_idx + result_idx = [result_idx] + elif isinstance(result_idx, list): + for i, idx in enumerate(result_idx): + if idx >= len(params) or idx < -len(params): + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") + if idx < 0: + result_idx[i] = len(params) + idx + else: + raise ValueError("result_idx should be a list of integers") + + return result_idx + + @abstractmethod + def _convert_torch_func(self) -> callable: + pass + + # --- Common helpers to align with PyTorch stream/device semantics --- + @staticmethod + def get_current_stream_functor() -> Callable[[], int]: + """Return a callable that reads Torch's current CUDA stream pointer. + + The returned lambda yields the raw CUDA stream handle of the current + PyTorch stream on the active device. It's a thunk (evaluated at call + time) so that any upstream stream guards are respected. If CUDA is + unavailable, it returns a lambda that yields 0. + """ + if torch.cuda.is_available(): + try: + torch.cuda._lazy_init() + current_device = torch._C._cuda_getDevice + get_stream = torch._C._cuda_getCurrentRawStream + return lambda: get_stream(current_device()) + except Exception: + # Fallback to Python API if internal handles are unavailable + return lambda: int(torch.cuda.current_stream().cuda_stream) + # CPU or CUDA unavailable: no stream semantics + return lambda: 0 + + @staticmethod + def get_current_device_functor() -> Callable[[], torch.device]: + """Return a callable that yields Torch's current device. + + Similar to the stream functor, we capture a callable that, when called, + fetches the current device according to PyTorch. On CPU or when CUDA is + unavailable, returns ``torch.device('cpu')``. 
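+
+        An illustrative use::
+
+            get_device = BaseKernelAdapter.get_current_device_functor()
+            device = get_device()  # e.g. torch.device('cuda', 0) or torch.device('cpu')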
+ """ + if torch.cuda.is_available(): + try: + torch.cuda._lazy_init() + current_device = torch._C._cuda_getDevice + return lambda: torch.device("cuda", current_device()) + except Exception: + return lambda: torch.device("cuda", torch.cuda.current_device()) + # CPU fallback + return lambda: torch.device("cpu") + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.func(*args, **kwds) + + def get_kernel_source(self, kernel_only: bool = True) -> str: + if kernel_only: + return self.mod.imports[0].inspect_source() + else: + return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source() + + def _post_init(self): + self.func = self._convert_torch_func() diff --git a/tilelang/original/tilelang/jit/adapter/ctypes/__init__.py b/tilelang/original/tilelang/jit/adapter/ctypes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6fdc84d6c71da609de4bfcdcc0bf0d9dceeb7d --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/ctypes/__init__.py @@ -0,0 +1 @@ +from .adapter import CtypesKernelAdapter # noqa: F401 diff --git a/tilelang/original/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/original/tilelang/jit/adapter/ctypes/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cac9d6fd049bc2db5875f09de59f97cee3f116 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/ctypes/adapter.py @@ -0,0 +1,302 @@ +"""The profiler and convert to torch utils""" + +from __future__ import annotations +import torch +from ..base import BaseKernelAdapter +import ctypes +from typing import Callable, Any +from tilelang import tvm as tvm +from tvm.target import Target +from tvm.relax import TensorType +from tvm import tir +from tilelang.jit.adapter.wrapper import TLWrapper +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.utils.target import determine_target +from tilelang.utils.language import retrieve_func_from_module + + +# TODO(lei): remove ctypes adapter. +class CtypesKernelAdapter(BaseKernelAdapter): + """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes. + + This adapter handles: + 1. Converting TIR functions to compiled CUDA libraries + 2. Managing dynamic shapes in tensor operations + 3. 
Wrapping C++ kernels for Python/PyTorch usage + """ + + # Class attributes to store compiled kernel information + target = "cuda" + ir_module: tvm.IRModule | None = None + # The global source code of the kernel -> global means the source code of the kernel + # that is not wrapped by the wrapper code + host_kernel_source: str | None = None + device_kernel_source: str | None = None + lib: ctypes.CDLL | None = None # Compiled library handle + # Maps symbolic variables to their corresponding buffer and shape indices + dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None + # Pass configs for the compiler + pass_configs: dict[str, Any] | None = None + + # Add new cache attributes + param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes + param_shapes: list[list] | None = None # Cache for parameter shapes + + def __init__( + self, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + """Initialize the adapter with the given TIR function or module. + + Args: + params: List of tensor types for inputs/outputs + result_idx: Indices of output tensors + target: Target platform (e.g., 'cuda') + func_or_mod: TIR function or module to be compiled + verbose: Enable verbose logging + """ + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + self.ir_module = func_or_mod + + # Cache parameter information during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + self.param_dtypes = [param.torch_dtype() for param in params] + self.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + native_shape.append(dim) # Keep tir.Var for dynamic dimensions + else: + native_shape.append(dim) + self.param_shapes.append(native_shape) + + self.dynamic_symbolic_map = self._process_dynamic_symbolic() + + self.target = Target.canon_target(determine_target(target)) + self.verbose = verbose + self.wrapper = TLWrapper(self.target) + self.lib_generator = LibraryGenerator(self.target, verbose=verbose) + self.lib_generator.assign_pass_configs(pass_configs) + self.lib_generator.assign_compile_flags(compile_flags) + + self.wrapper.assign_optimized_module(self.ir_module) + self.wrapper.assign_pass_configs(pass_configs) + self.wrapper.assign_host_module(host_mod) + self.wrapper.assign_device_module(device_mod) + self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True)) + + self.lib_generator.update_lib_code(self.wrapped_source) + self.lib_generator.compile_lib() + self.lib = self.lib_generator.load_lib() + self.lib.init() + + self._post_init() + + @classmethod + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, 
Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source + adapter.pass_configs = pass_configs + + if isinstance(func_or_mod, tir.PrimFunc): + adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + # Cache parameter information during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + adapter.param_dtypes = [param.torch_dtype() for param in params] + adapter.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + native_shape.append(dim) # Keep tir.Var for dynamic dimensions + else: + native_shape.append(dim) + adapter.param_shapes.append(native_shape) + + adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic() + + adapter.target = Target.canon_target(determine_target(target)) + adapter.verbose = verbose + adapter.lib_generator = LibraryGenerator(adapter.target, verbose=verbose) + adapter.lib_generator.assign_pass_configs(pass_configs) + adapter.lib_generator.assign_compile_flags(compile_flags) + adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path) + adapter.lib.init() + + adapter._post_init() + return adapter + + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: + """Extract information about dynamic shapes from the TIR function. + + Maps symbolic variables to their corresponding (id, buffer_index, dimension) + for runtime shape resolution. + id represents shape or stride, 0 represents shape, 1 represents stride + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map = {} + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): + dynamic_symbolic_map[shape] = (0, i, j) + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, stride in enumerate(buffer.strides): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): + dynamic_symbolic_map[stride] = (1, i, j) + return dynamic_symbolic_map + + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): + """Low-level function to call the compiled CUDA kernel. + + Converts PyTorch tensor pointers to C void pointers for ctypes interface. + """ + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] + ctypes_args.append(ctypes.c_void_p(stream)) + self.lib.call(*ctypes_args) + + def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): + """High-level wrapper for kernel execution. + + Handles: + 1. Input validation + 2. Output tensor allocation + 3. Dynamic shape resolution + 4. 
CUDA stream management + + Args: + ins: Input PyTorch tensors + stream: Optional CUDA stream for asynchronous execution + + Returns: + Single tensor or list of tensors containing the kernel results + """ + if len(ins) + len(self.result_idx) != len(self.params): + raise ValueError( + f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" + ) + ins_idx = 0 + args = [] + + # tensor pointers + for i in range(len(self.params)): + if i in self.result_idx: + dtype = self.param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in self.param_shapes[i]: + if isinstance(s, tir.Var): + ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s] + shape.append(ins[ref_tensor_idx].shape[ref_shape_idx]) + else: # Already converted to Python int during initialization + shape.append(s) + device = ins[0].device if len(ins) > 0 else torch.cuda.current_device() + tensor = torch.empty(*shape, dtype=dtype, device=device) + else: + tensor = ins[ins_idx] + ins_idx += 1 + args.append(tensor) + + # dynamic symbolics + for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): + if ref_id == 0: + args.append(ins[buffer_idx].shape[shape_idx]) + else: + args.append(ins[buffer_idx].stride(shape_idx)) + + # if stream is not None, we need to pass the stream to the library + if stream is None: + if str(self.target).startswith("cuda") and torch.cuda.is_available(): + stream = torch.cuda.current_stream().cuda_stream + else: + stream = 0 + + self._forward_from_prebuild_lib(*args, stream=stream) + + if len(self.result_idx) == 1: + return args[self.result_idx[0]] + else: + return [args[i] for i in self.result_idx] + + def _convert_torch_func(self) -> Callable: + """Returns a PyTorch-compatible function wrapper for the kernel.""" + return self._wrap_forward_from_prebuild_lib + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) + + @property + def srcpath(self): + """Returns the source path of the compiled library.""" + return self.lib_generator.srcpath + + @property + def libpath(self): + """Returns the path to the compiled library.""" + return self.lib_generator.libpath + + @property + def lib_code(self): + """Returns the code of the compiled library.""" + return self.lib_generator.lib_code + + @property + def is_dynamic(self): + """Indicates whether the kernel handles dynamic shapes.""" + return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0 + + def get_kernel_source(self, kernel_only: bool = False): + """Returns the source code of the compiled kernel.""" + if kernel_only: + return self.device_kernel_source + else: + # Wrapper only has host kernel source + return self.host_kernel_source diff --git a/tilelang/original/tilelang/jit/adapter/cutedsl/__init__.py b/tilelang/original/tilelang/jit/adapter/cutedsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e25899a1d0e2e4f279b87d5c48e5b40d430e8679 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cutedsl/__init__.py @@ -0,0 +1,16 @@ +"""CuTeDSL Backend for TileLang. + +This module provides runtime compilation support using NVIDIA's CuTeDSL API. 
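+
+A minimal sketch of opting into this backend (assumes a CUDA target and the
+``nvidia-cutlass-dsl`` package providing ``cutlass.cute``)::
+
+    @tilelang.jit(execution_backend="cutedsl")
+    def my_kernel(M, N):
+        ...  # build and return a TileLang PrimFunc here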
+""" + +__all__ = [ + "CuTeDSLKernelAdapter", + "TLCuTeDSLSourceWrapper", + "CuTeDSLLibraryGenerator", + "check_cutedsl_available", +] + +from .checks import check_cutedsl_available # noqa: F401 +from .adapter import CuTeDSLKernelAdapter # noqa: F401 +from .wrapper import TLCuTeDSLSourceWrapper # noqa: F401 +from .libgen import CuTeDSLLibraryGenerator # noqa: F401 diff --git a/tilelang/original/tilelang/jit/adapter/cutedsl/adapter.py b/tilelang/original/tilelang/jit/adapter/cutedsl/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ab5db4d26ddcfed8bef2223da310b3e86eda97 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cutedsl/adapter.py @@ -0,0 +1,368 @@ +from __future__ import annotations +import logging +from typing import Any, Callable + +import torch +from tvm import tir +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.engine.param import KernelParam +from tilelang.jit.adapter.wrapper import TLPyWrapper +from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available +from tilelang.jit.adapter.cutedsl.libgen import CuTeDSLLibraryGenerator +from tilelang.utils.language import retrieve_func_from_module +from tilelang.utils.target import determine_target +from tilelang.jit.adapter.base import BaseKernelAdapter + +logger = logging.getLogger(__name__) + + +class CuTeDSLKernelAdapter(BaseKernelAdapter): + pymodule = None + + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + check_cutedsl_available() + + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + gsym = func_or_mod.attrs.get("global_symbol") + if gsym is None: + raise ValueError("PrimFunc is missing required attr 'global_symbol'") + self.ir_module = tvm.IRModule({gsym: func_or_mod}) + else: + self.ir_module = func_or_mod + + # Cache parameter information during initialization + self.param_dtypes = [param.torch_dtype() for param in params] + self.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + self.param_shapes.append(native_shape) + + self.dynamic_symbolic_map, self.dynamic_symbolic_order = self._process_dynamic_symbolic() + + self.target = Target.canon_target(determine_target(target)) + self.verbose = verbose + self.wrapper = TLPyWrapper(self.target) + self.wrapper.assign_optimized_module(self.ir_module) + self.wrapper.assign_pass_configs(pass_configs) + self.wrapper.assign_host_module(host_mod) + self.wrapper.assign_device_module(device_mod) + wrapper_result = self.wrapper.wrap(device_kernel_source) + self.host_func = wrapper_result["host_func"] + self.function_names = wrapper_result["function_names"] + self.tma_cpp_init_code = wrapper_result["tma_cpp_init_code"] + self.tma_lib_name = wrapper_result["tma_lib_name"] + self.launcher_cpp_code = 
wrapper_result.get("launcher_cpp_code", None) + self.launcher_lib_name = wrapper_result.get("launcher_lib_name", None) + + self.lib_generator = CuTeDSLLibraryGenerator(self.target, self.verbose) + self.lib_generator.update_lib_code(self.device_kernel_source) + self.lib_generator.update_host_func(self.host_func) + self.lib_generator.update_tma_cpp_init_code(self.tma_cpp_init_code) + self.lib_generator.update_tma_lib_name(self.tma_lib_name) + self.lib_generator.update_launcher_cpp_code(self.launcher_cpp_code) + self.lib_generator.update_launcher_lib_name(self.launcher_lib_name) + self.lib_generator.assign_compile_flags(compile_flags) + self.lib_generator.compile_lib() + self.lib_generator.load_lib() + self.libpath = self.lib_generator.libpath + self.device_kernel_source = open(self.libpath).read() + self.pymodule = self.lib_generator.pymodule + + self._post_init() + + @classmethod + def from_database( + cls, + params: list[KernelParam], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + gsym = func_or_mod.attrs.get("global_symbol") + if gsym is None: + raise ValueError("PrimFunc is missing required attr 'global_symbol'") + adapter.ir_module = tvm.IRModule({gsym: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + # Cache parameter information during initialization + adapter.param_dtypes = [param.torch_dtype() for param in params] + adapter.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + adapter.param_shapes.append(native_shape) + + adapter.dynamic_symbolic_map, adapter.dynamic_symbolic_order = adapter._process_dynamic_symbolic() + + adapter.target = Target.canon_target(determine_target(target)) + adapter.verbose = verbose + adapter.lib_generator = CuTeDSLLibraryGenerator(adapter.target, adapter.verbose) + adapter.lib_generator.assign_compile_flags(compile_flags) + adapter.lib_generator.load_lib(lib_path=kernel_lib_path) + adapter.libpath = kernel_lib_path + adapter.kernel_global_source = open(adapter.libpath).read() + adapter.pymodule = adapter.lib_generator.pymodule + + adapter._post_init() + return adapter + + def _process_dynamic_symbolic(self) -> tuple[dict[tir.Var, tuple[int, int, int]], list[tir.Var]]: + """Extract information about dynamic symbols from the TIR function. 
+ + We follow the same ordering semantics as `TLCUDASourceWrapper.get_dynamic_symbolic_set()`: + 1) dynamic symbols in buffer shapes (in prim_func param order) + 2) then dynamic symbols in buffer strides + + The mapping encodes: + - id=0: shape var -> (0, buffer_param_index, dim_index) + - id=1: stride var -> (1, buffer_param_index, stride_index) + + Returns: + (dynamic_symbolic_map, dynamic_symbolic_order) + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] = {} + dynamic_symbolic_order: list[tir.Var] = [] + + def unique_push_back(v: tir.Var, entry: tuple[int, int, int]): + if v in dynamic_symbolic_map: + return + dynamic_symbolic_map[v] = entry + dynamic_symbolic_order.append(v) + + # 1) Shapes + for i, param in enumerate(params): + if param not in buffer_map: + continue + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var): + unique_push_back(shape, (0, i, j)) + + # 2) Strides + for i, param in enumerate(params): + if param not in buffer_map: + continue + buffer = buffer_map[param] + if buffer.strides is None: + continue + for j, stride in enumerate(buffer.strides): + if isinstance(stride, tir.Var): + unique_push_back(stride, (1, i, j)) + + return dynamic_symbolic_map, dynamic_symbolic_order + + def get_kernel_source(self, kernel_only: bool = True) -> str | None: + """Get the CUDA kernel source code. + + Returns + ------- + str | None + The kernel source code, or None if not available + """ + return self.device_kernel_source + + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): + """Low-level function to call the compiled CUDA kernel.""" + result = self.pymodule.call(*args, stream=stream) + + # After first call, save cubin to cache if needed + self._save_cubin_to_cache_if_needed() + + return result + + def _save_cubin_to_cache_if_needed(self): + """Save cubin to cache directory after first execution. + + This is called after the first kernel execution to ensure the generated + cubin file is copied to the cache directory for future reuse. + """ + if getattr(self, "_cubin_saved_to_cache", False): + return + self._cubin_saved_to_cache = True + + # Check if we have a cache path (set by kernel_cache) + cache_path = getattr(self, "_cache_path", None) + if cache_path is None: + return + + import os + import shutil + + # Source cubin path (in temp directory) + src_py_path = self.libpath + src_py_stem = os.path.splitext(os.path.basename(src_py_path))[0] + src_dir = os.path.dirname(src_py_path) + src_cubin_path = os.path.join(src_dir, f"{src_py_stem}.cubin") + + if not os.path.exists(src_cubin_path): + return + + # Destination cubin path (in cache directory) + dst_cubin_path = os.path.join(cache_path, "kernel.cubin") + + if os.path.exists(dst_cubin_path): + return + + # Copy cubin to cache + try: + shutil.copy2(src_cubin_path, dst_cubin_path) + logger.debug(f"Saved CuTeDSL cubin to cache: {dst_cubin_path}") + except Exception as e: + logger.warning(f"Failed to save cubin to cache: {e}", exc_info=True) + + def _wrap_forward_from_prebuild_lib(self, *ins: Any, stream: int | None = None): + """High-level wrapper for kernel execution. + + Handles: + 1. Input validation + 2. Output tensor allocation + 3. Dynamic shape resolution + 4. 
CUDA stream management + + Args: + ins: Input arguments (may include scalars and tensors) + stream: Optional CUDA stream for asynchronous execution + + Returns: + Single tensor or list of tensors containing the kernel results + """ + if len(ins) + len(self.result_idx) != len(self.params): + raise ValueError( + f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" + ) + + # Materialize args in PrimFunc param order (inputs + allocated outputs) + ins_idx = 0 + param_values: list[Any] = [None] * len(self.params) + for i in range(len(self.params)): + if i in self.result_idx: + continue + param_values[i] = ins[ins_idx] + ins_idx += 1 + + first_tensor = next((v for v in param_values if isinstance(v, torch.Tensor)), None) + if first_tensor is None: + raise ValueError("Expected at least one torch.Tensor argument to infer CUDA device") + + args: list[Any] = [] + + # tensor pointers + for i in range(len(self.params)): + if i in self.result_idx: + dtype = self.param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in self.param_shapes[i]: + if isinstance(s, tir.Var): + ref_id, ref_param_idx, ref_dim_idx = self.dynamic_symbolic_map[s] + ref_val = param_values[ref_param_idx] + if not isinstance(ref_val, torch.Tensor): + raise TypeError(f"Dynamic shape/stride var {s} refers to a non-tensor param at index {ref_param_idx}") + if ref_id == 0: + shape.append(ref_val.shape[ref_dim_idx]) + elif ref_id == 1: + # Stride vars are not expected in output shapes, but handle defensively. + shape.append(ref_val.stride()[ref_dim_idx]) + else: + raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}") + else: # Already converted to Python int during initialization + shape.append(s) + tensor = torch.empty(*shape, dtype=dtype, device=first_tensor.device) + param_values[i] = tensor + else: + tensor = param_values[i] + args.append(tensor) + + # dynamic symbolics + for sym in self.dynamic_symbolic_order: + ref_id, buffer_idx, dim_idx = self.dynamic_symbolic_map[sym] + ref_val = param_values[buffer_idx] + if not isinstance(ref_val, torch.Tensor): + raise TypeError(f"Dynamic symbolic var {sym} refers to a non-tensor param at index {buffer_idx}") + if ref_id == 0: + args.append(ref_val.shape[dim_idx]) + elif ref_id == 1: + args.append(ref_val.stride()[dim_idx]) + else: + raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}") + + # if stream is not None, we need to pass the stream to the library + if stream is None: + if str(self.target).startswith("cuda") and torch.cuda.is_available(): + stream = torch.cuda.current_stream().cuda_stream + else: + stream = 0 + + self._forward_from_prebuild_lib(*args, stream=stream) + + if len(self.result_idx) == 1: + return args[self.result_idx[0]] + else: + return [args[i] for i in self.result_idx] + + def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]: + """Convert to a PyTorch-compatible function. 
+ + Returns + ------- + Callable[..., torch.Tensor | list[torch.Tensor]] + A callable function that takes tensors and returns tensor(s) + """ + return self._wrap_forward_from_prebuild_lib + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) diff --git a/tilelang/original/tilelang/jit/adapter/cutedsl/checks.py b/tilelang/original/tilelang/jit/adapter/cutedsl/checks.py new file mode 100644 index 0000000000000000000000000000000000000000..ced8ea7c30bc3836e36d1b7bfd833b28fb811e5a --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cutedsl/checks.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from importlib import metadata as _importlib_metadata +from importlib.util import find_spec as _find_spec +import os + +_CUTEDSL_PUBLIC_DIST = "nvidia-cutlass-dsl" +_CUTEDSL_MIN_VERSION = (4, 3, 1) +_VERSION_TRIPLE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)") + + +def _parse_version_triple(version_str: str) -> tuple[int, int, int] | None: + """Parse a best-effort (major, minor, patch) triple from a version string. + + We intentionally avoid importing heavy/optional version parsers. For our + minimum requirement (>= 4.3.1), a numeric triple comparison is sufficient. + """ + m = _VERSION_TRIPLE_RE.search(version_str) + if not m: + return None + return int(m.group(1)), int(m.group(2)), int(m.group(3)) + + +def _min_version_str() -> str: + return ".".join(map(str, _CUTEDSL_MIN_VERSION)) + + +def _requirement_spec() -> str: + return f"{_CUTEDSL_PUBLIC_DIST}>={_min_version_str()}" + + +def check_cutedsl_available() -> None: + """Fail fast if the CuTeDSL backend cannot be used in this Python environment. + + Policy: + - If the public distribution `nvidia-cutlass-dsl` is installed, require version >= a minimum supported version. + - Regardless of distribution metadata, require that `cutlass.cute` is importable. + + This intentionally does not mention or special-case any internal distributions. + """ + # 1) Version gate (only when the public dist metadata is present) + try: + dist_version = _importlib_metadata.version(_CUTEDSL_PUBLIC_DIST) + except _importlib_metadata.PackageNotFoundError: + dist_version = None + except Exception: + # Metadata is best-effort; don't block internal/nonstandard installs here. + dist_version = None + + if dist_version is not None: + parsed = _parse_version_triple(dist_version) + if parsed is None or parsed < _CUTEDSL_MIN_VERSION: + req = _requirement_spec() + raise ImportError( + f"CuTeDSL backend requires `{req}`, but found version `{dist_version}`. Please run: `pip install -U '{req}'`." + ) + + # 2) Capability probe: keep it cheap. + # Importing cutlass/cute can be expensive and defeats our lazy-import design, + # especially on cache hits. We only require that the module is importable. + cutlass_spec = _find_spec("cutlass") + if cutlass_spec is None: + req = _requirement_spec() + raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).") + + # Avoid find_spec("cutlass.cute") which can be surprisingly expensive. + # Instead, check for a 'cute' submodule/package under cutlass's search locations. 
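+    # i.e. accept either <cutlass dir>/cute/ (a package) or <cutlass dir>/cute.py
+    # (a module) as evidence that the installed CUTLASS DSL ships CuTe support.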
+ locs = getattr(cutlass_spec, "submodule_search_locations", None) + has_cute = False + if locs: + for base in locs: + if os.path.isdir(os.path.join(base, "cute")) or os.path.isfile(os.path.join(base, "cute.py")): + has_cute = True + break + + if not has_cute: + req = _requirement_spec() + raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).") diff --git a/tilelang/original/tilelang/jit/adapter/cutedsl/libgen.py b/tilelang/original/tilelang/jit/adapter/cutedsl/libgen.py new file mode 100644 index 0000000000000000000000000000000000000000..3dac6b141a08b1e53a8159e4a597c4e96c56b1af --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cutedsl/libgen.py @@ -0,0 +1,124 @@ +"""CuTeDSL Library Generator for TileLang. + +This module provides library generation functionality for the CuTeDSL backend. +""" + +from __future__ import annotations +import importlib.util +import os +import tempfile +import subprocess + +from tvm.target import Target + +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.jit.adapter.utils import is_cutedsl_target + + +class CuTeDSLLibraryGenerator(LibraryGenerator): + host_func: str | None = None + tma_cpp_init_code: str | None = None + tma_lib_name: str | None = None + launcher_cpp_code: str | None = None + launcher_lib_name: str | None = None + pymodule = None + + def __init__(self, target: Target, verbose: bool = False): + super().__init__(target, verbose) + + @staticmethod + def import_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def update_host_func(self, host_func: str): + self.host_func = host_func + + def update_tma_cpp_init_code(self, tma_cpp_init_code: str): + self.tma_cpp_init_code = tma_cpp_init_code + + def update_tma_lib_name(self, tma_lib_name: str): + self.tma_lib_name = tma_lib_name + + def update_launcher_cpp_code(self, launcher_cpp_code: str): + self.launcher_cpp_code = launcher_cpp_code + + def update_launcher_lib_name(self, launcher_lib_name: str): + self.launcher_lib_name = launcher_lib_name + + def load_lib(self, lib_path: str | None = None): + if lib_path is None: + if self.libpath is None: + raise RuntimeError("CuTeDSLLibraryGenerator.libpath is not set; call compile_lib() first or pass lib_path explicitly.") + lib_path = self.libpath + + self.pymodule = self.import_from_file("kernel", lib_path) + + def compile_lib(self, timeout: float = None): + if self.host_func is None: + raise RuntimeError("CuTeDSLLibraryGenerator.host_func is not set; call update_host_func() before compile_lib().") + target = self.target + if is_cutedsl_target(target): + # Use a dedicated temp directory per kernel so CuTeDSL artifacts (e.g. kept .cubin) + # never pollute user CWD, and are easy to locate alongside the generated module. + work_dir = tempfile.mkdtemp(prefix="tilelang_cutedsl_") + src_path = os.path.join(work_dir, "kernel.py") + with open(src_path, "w") as f: + # Note: lib_code (containing @cute.kernel definitions) is embedded + # inside host_func's _generate_cubin_if_needed function, so we only + # write host_func here. This ensures cute imports are lazy-loaded. 
+ f.write(self.host_func) + + # Compile C++ launcher library if needed + if self.launcher_cpp_code is not None: + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".cpp", + delete=False, + ) as launcher_src: + launcher_src.write(self.launcher_cpp_code) + launcher_src_path = launcher_src.name + + # Generate launcher lib under the same directory as the source file + launcher_lib_path = os.path.join(os.path.dirname(src_path), self.launcher_lib_name) + + # Get TVM FFI compiler flags using tvm_ffi.libinfo API + try: + import tvm_ffi.libinfo + + include_paths = tvm_ffi.libinfo.include_paths() + tvm_cxxflags = [f"-I{path}" for path in include_paths] + lib_path = tvm_ffi.libinfo.find_libtvm_ffi() + lib_dir = os.path.dirname(lib_path) + tvm_ldflags = [f"-L{lib_dir}", "-ltvm_ffi"] + except (ImportError, RuntimeError): + # tvm_ffi unavailable or libinfo functions failed + tvm_cxxflags = [] + tvm_ldflags = [] + + # Compile with nvcc (need CUDA driver API) + compile_cmd = [ + "nvcc", + "-shared", + "-Xcompiler=-fPIC", + "-lcuda", + *tvm_cxxflags, + *tvm_ldflags, + "-o", + launcher_lib_path, + launcher_src_path, + ] + + result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0: + raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}") + + self.launcher_libpath = launcher_lib_path + self.launcher_libname = self.launcher_lib_name + + self.srcpath = src_path + self.libpath = src_path + else: + raise ValueError(f"Unsupported target: {target}") diff --git a/tilelang/original/tilelang/jit/adapter/cutedsl/wrapper.py b/tilelang/original/tilelang/jit/adapter/cutedsl/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c20d2ec67983aecec5a0270d0cd19ef7dd24d6d9 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cutedsl/wrapper.py @@ -0,0 +1,1354 @@ +"""CuTeDSL Source Wrapper for TileLang. + +This module provides C++ kernel launcher generation for the CuTeDSL backend. 
+
+Key features:
+- Automatic C++ launcher generation with CUDA Driver API
+- TMA descriptors on HOST memory, passed via __grid_constant__ (no device copy needed)
+- cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space
+- Support for single and multiple kernel launches
+- Complete cache system integration
+"""
+
+from __future__ import annotations
+from typing import Any, ClassVar
+
+from tvm import IRModule
+from tvm.target import Target
+from tvm.tir.stmt_functor import post_order_visit
+
+from tilelang import tvm as tvm
+from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
+from tilelang.jit.adapter.utils import (
+    extract_python_func_declaration,
+    pythonic_expr,
+    parse_tma_descriptor_args,
+)
+
+# =============================================================================
+# C++ LAUNCHER TEMPLATES (using named parameters for clarity)
+# =============================================================================
+
+# TMA single descriptor initialization template (writes to caller-provided host array)
+# No device copy needed - cuLaunchKernel handles __grid_constant__ params automatically
+CPP_TMA_DESC_INIT_TEMPLATE = """\
+    // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
+    {{
+        uint64_t globalDim[{rank}] = {{{global_dim_values}}};
+        uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
+        uint32_t boxDim[{rank}] = {{{box_dim_values}}};
+        uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
+
+        result = cuTensorMapEncodeTiled(
+            &tma_descs[{desc_idx}],
+            static_cast<CUtensorMapDataType>({dtype}),
+            {rank},
+            reinterpret_cast<void*>({tensor_name}_ptr),
+            globalDim,
+            globalStrides,
+            boxDim,
+            elemStrides,
+            static_cast<CUtensorMapInterleave>({interleave}),
+            static_cast<CUtensorMapSwizzle>({swizzle}),
+            static_cast<CUtensorMapL2promotion>({l2_promotion}),
+            static_cast<CUtensorMapFloatOOBfill>({oob_fill})
+        );
+
+        if (result != CUDA_SUCCESS) {{
+            std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\\n";
+            return result;
+        }}
+    }}
+"""
+
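+# Illustrative example (placeholder names only, never emitted verbatim): for a single
+# 2-D tensor argument "A", the template above renders to roughly
+#
+#   // Descriptor 0: A_desc (tensor: A)
+#   {
+#       uint64_t globalDim[2] = {...};
+#       uint64_t globalStrides[1] = {...};
+#       uint32_t boxDim[2] = {...};
+#       uint32_t elemStrides[2] = {...};
+#       result = cuTensorMapEncodeTiled(&tma_descs[0], ...);
+#   }
+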
+# TMA single im2col descriptor initialization template (writes to caller-provided host array)
+# Align field ordering with NVRTC wrapper (cuTensorMapEncodeIm2col signature).
+CPP_TMA_IM2COL_DESC_INIT_TEMPLATE = """\
+    // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
+    {{
+        uint64_t globalDim[{rank}] = {{{global_dim_values}}};
+        uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
+        uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
+        int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
+        int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};
+
+        result = cuTensorMapEncodeIm2col(
+            &tma_descs[{desc_idx}],
+            static_cast<CUtensorMapDataType>({dtype}),
+            {rank},
+            reinterpret_cast<void*>({tensor_name}_ptr),
+            globalDim,
+            globalStrides,
+            lowerCorner,
+            upperCorner,
+            static_cast<uint32_t>({channels_per_pixel}),
+            static_cast<uint32_t>({pixels_per_column}),
+            elemStrides,
+            static_cast<CUtensorMapInterleave>({interleave}),
+            static_cast<CUtensorMapSwizzle>({swizzle}),
+            static_cast<CUtensorMapL2promotion>({l2_promotion}),
+            static_cast<CUtensorMapFloatOOBfill>({oob_fill})
+        );
+
+        if (result != CUDA_SUCCESS) {{
+            std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\\n";
+            return result;
+        }}
+    }}
+"""
+
+# TMA initialization function template (writes to caller-provided host array)
+# __grid_constant__ allows kernel to receive TMA descriptor by value via param space
+CPP_TMA_INIT_FUNC_TEMPLATE = """\
+CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
+    // Initialize {num_descs} TMA descriptor(s) in caller-provided host array
+    // cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
+    CUresult result;
+
+{desc_init_code}
+
+    return CUDA_SUCCESS;
+}}
+"""
+
+# Kernel initialization template
+CPP_KERNEL_INIT_TEMPLATE = """\
+    // Find and configure kernel {kernel_idx}: {kernel_name}
+    result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);
+    if (result != CUDA_SUCCESS) {{
+        std::cerr << "Failed to find kernel {kernel_name}: " << result << "\\n";
+        return result;
+    }}
+
+    if ({smem_size} > 0) {{
+        result = cuFuncSetAttribute(g_kernels[{kernel_idx}],
+                                    CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                                    {smem_size});
+        if (result != CUDA_SUCCESS) {{
+            std::cerr << "Failed to set smem for {kernel_name}: " << result << "\\n";
+            return result;
+        }}
+    }}
+"""
+
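+# Note on the kernel-init template above: raising CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
+# is what allows a later cuLaunchKernel to request more dynamic shared memory than the
+# default 48 KiB per-block limit; without this opt-in, such launches fail at launch time.
+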
+# TMA launch initialization template (host memory mode - uses __grid_constant__)
+# Kernel receives TMA descriptor by value: .param .align 128 .b8 xxx_param[128]
+CPP_TMA_LAUNCH_INIT_TEMPLATE = """\
+    // Declare stack-local TMA descriptor array (eliminates concurrency race)
+    CUtensorMap tma_descs[{num_tma_descs}];
+
+    // Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
+    // NOTE: We intentionally do NOT reuse/cache descriptors across launches.
+    // Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
+    // and correctness beats micro-optimizations.
+    result = tma_init(tma_descs, {tma_tensor_args});
+    if (result != CUDA_SUCCESS) {{
+        std::cerr << "Failed to initialize TMA descriptors: " << result << "\\n";
+        return result;
+    }}
+"""
+
+# Kernel launch template
+CPP_KERNEL_LAUNCH_TEMPLATE = """\
+    // Launch kernel {kernel_idx}: {kernel_name}
+    {{
+        void* args[] = {{{kernel_args}}};
+        result = cuLaunchKernel(
+            g_kernels[{kernel_idx}],
+            {grid_x}, {grid_y}, {grid_z},
+            {block_x}, {block_y}, {block_z},
+            {smem_size},
+            stream,
+            args,
+            nullptr
+        );
+        if (result != CUDA_SUCCESS) {{
+            std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\\n";
+            return result;
+        }}
+    }}
+"""
+
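+# Note on the launch template above: each entry in args[] is the *address* of a kernel
+# parameter value. For ordinary pointers/scalars that is &var; for a __grid_constant__
+# CUtensorMap parameter it is the address of the host-side descriptor itself, and
+# cuLaunchKernel copies the full 128-byte CUtensorMap into the kernel's .param space.
+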
+# Complete C++ launcher template
+CPP_LAUNCHER_TEMPLATE = """\
+#include <cuda.h>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <string>
+#include <cstring>
+#include <cctype>
+
+// TVM Headers
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/string.h>
+
+// Cached module handle
+static CUmodule g_module = nullptr;
+static bool g_module_initialized = false;
+
+// Cached kernel functions
+static CUfunction g_kernels[{num_kernels}] = {{nullptr}};
+static bool g_kernels_initialized = false;
+
+// Find kernel by pattern (substring match, prefer base name over _N variants)
+CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
+    CUresult result;
+    unsigned int num_funcs = 0;
+
+    result = cuModuleGetFunctionCount(&num_funcs, module);
+    if (result != CUDA_SUCCESS) {{
+        std::cerr << "Failed to get function count: " << result << "\\n";
+        return result;
+    }}
+
+    std::vector<CUfunction> func_list(num_funcs);
+    result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
+    if (result != CUDA_SUCCESS) {{
+        std::cerr << "Failed to enumerate functions: " << result << "\\n";
+        return result;
+    }}
+
+    // Collect substring matches, separating base name from _N variants
+    std::vector<std::pair<std::string, CUfunction>> base_matches;     // pattern not followed by _digit
+    std::vector<std::pair<std::string, CUfunction>> variant_matches;  // pattern followed by _digit
+
+    size_t pattern_len = std::strlen(pattern);
+
+    for (unsigned int i = 0; i < num_funcs; i++) {{
+        const char* func_name = nullptr;
+        result = cuFuncGetName(&func_name, func_list[i]);
+        if (result != CUDA_SUCCESS || func_name == nullptr) {{
+            std::cerr << "Failed to get function name: " << result << "\\n";
+            return result;
+        }}
+
+        std::string name_str(func_name);
+        size_t pos = name_str.find(pattern);
+
+        if (pos != std::string::npos) {{
+            // Found substring match
+            size_t after_pattern = pos + pattern_len;
+
+            // Check what follows the pattern
+            if (after_pattern < name_str.length() &&
+                name_str[after_pattern] == '_' &&
+                after_pattern + 1 < name_str.length() &&
+                std::isdigit(name_str[after_pattern + 1])) {{
+                // Pattern followed by _digit (e.g., "main_kernel_1")
+                variant_matches.push_back({{name_str, func_list[i]}});
+            }} else {{
+                // Pattern not followed by _digit (e.g., "main_kernel" itself)
+                base_matches.push_back({{name_str, func_list[i]}});
+            }}
+        }}
+    }}
+
+    // Decision logic: prefer base matches over variant matches
+    if (!base_matches.empty()) {{
+        if (base_matches.size() == 1) {{
+            *out_func = base_matches[0].second;
+            return CUDA_SUCCESS;
+        }}
+
+        // Multiple base matches - ambiguous
+        std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
+                  << " base kernels (ambiguous). Matches found:\\n";
+        for (const auto& match : base_matches) {{
+            std::cerr << "  - " << match.first << "\\n";
+        }}
+        std::cerr << "Please use a more specific pattern.\\n";
+        return CUDA_ERROR_NOT_FOUND;
+    }}
+
+    // No base matches, try variant matches
+    if (!variant_matches.empty()) {{
+        if (variant_matches.size() == 1) {{
+            *out_func = variant_matches[0].second;
+            return CUDA_SUCCESS;
+        }}
+
+        // Multiple variant matches - ambiguous
+        std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
+                  << " variant kernels (ambiguous). Matches found:\\n";
+        for (const auto& match : variant_matches) {{
+            std::cerr << "  - " << match.first << "\\n";
+        }}
+        std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\\n";
+        return CUDA_ERROR_NOT_FOUND;
+    }}
+
+    // No matches at all
+    std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\\n";
+    return CUDA_ERROR_NOT_FOUND;
+}}
+
+
+// Initialize CUDA module (called once on first launch)
+static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{
+    if (g_module_initialized) return CUDA_SUCCESS;
+
+    CUresult result;
+    result = cuInit(0);
+    if (result != CUDA_SUCCESS) return result;
+
+    std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
+    if (!cubin_file) {{
+        std::cerr << "Failed to open cubin file: " << cubin_path << "\\n";
+        return CUDA_ERROR_FILE_NOT_FOUND;
+    }}
+
+    std::vector<char> cubin_data((std::istreambuf_iterator<char>(cubin_file)),
+                                 std::istreambuf_iterator<char>());
+    cubin_file.close();
+
+    if (cubin_data.empty()) {{
+        std::cerr << "Empty cubin file: " << cubin_path << "\\n";
+        return CUDA_ERROR_INVALID_IMAGE;
+    }}
+
+    result = cuModuleLoadData(&g_module, cubin_data.data());
+    if (result != CUDA_SUCCESS) {{
+        std::cerr << "Failed to load CUDA module: " << result << "\\n";
+        return result;
+    }}
+
+    g_module_initialized = true;
+    return CUDA_SUCCESS;
+}}
+
+// Initialize all kernel functions (called once after module load)
+static CUresult tilelang_init_kernels() {{
+    if (g_kernels_initialized) return CUDA_SUCCESS;
+    CUresult result;
+
+{kernel_inits}
+
+    g_kernels_initialized = true;
+    return CUDA_SUCCESS;
+}}
+
+// TMA descriptor initialization (host-side)
+{tma_init_func}
+
+// Main kernel launcher
+extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{
+    CUresult result;
+
+    std::string cubin_path_str(reinterpret_cast<const char*>(cubin_path.data()), cubin_path.size());
+    result = tilelang_init_cuda_module(cubin_path_str);
+    if (result != CUDA_SUCCESS) return result;
+
+    result = tilelang_init_kernels();
+    if (result != CUDA_SUCCESS) return result;
+
+{get_ptr_code}
+    CUstream stream = (CUstream)_stream;
+
+{tma_init_in_launch}
+
+{kernel_launches}
+
+    return CUDA_SUCCESS;
+}}
+
+// Cleanup function
+extern "C" CUresult cleanup_module() {{
+    if (g_module_initialized && g_module != nullptr) {{
+        cuModuleUnload(g_module);
+        g_module = nullptr;
+        g_module_initialized = false;
+    }}
+
+    g_kernels_initialized = false;
+
+    return CUDA_SUCCESS;
+}}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
+"""
+
+# =============================================================================
+# PYTHON CUBIN GENERATION TEMPLATES
+# =============================================================================
+
+# TMA descriptor atom initialization template
+CUBIN_TMA_ATOM_INIT_TEMPLATE = """\
+        {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))"""
+
+# 
Kernel launch call template +CUBIN_KERNEL_LAUNCH_TEMPLATE = """\ + {function_name}({call_args}).launch( + grid=[{grid_x}, {grid_y}, {grid_z}], + block=[{block_x}, {block_y}, {block_z}], + smem={smem_size}, + stream=stream, + )""" + +# Fake tensor creation template +CUBIN_FAKE_TENSOR_TEMPLATE = """\ + __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)], {arg_name}.shape, stride_order={arg_name}.dim_order()[::-1], assumed_align=16)""" + +# Complete cubin generation code template +# {lib_code} contains the @cute.kernel definitions and is embedded here +CUBIN_GEN_CODE_TEMPLATE = """\ +{lib_code} + + @cute.jit + def kernel_wrapper({wrapper_args}): +{tma_init_code}{kernel_launches} + + # Compile kernels to generate cubin +{fake_tensor_code}{fake_tma_tensor_code} __fake_stream__ = make_fake_stream() + # Always generate cubin under a unique staging directory to avoid concurrent + # processes clobbering each other's intermediate artifacts. + _staging_dir = Path(tempfile.mkdtemp( + prefix=Path(__file__).stem + ".cubin.staging.", + dir=_module_dir, + )) + try: + _kernel_wrapper = cute.compile( + kernel_wrapper, + {compile_args}, + options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}", + ) + + # CuTeDSL generates a long, mangled cubin filename that includes argument/type info, + # e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin. + _cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime) + if len(_cubin_files) != 1: + raise RuntimeError( + f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}" + ) + os.replace(_cubin_files[0], _cubin_path) + finally: + shutil.rmtree(_staging_dir, ignore_errors=True)""" + +# ============================================================================= +# PYTHON HOST FUNCTION TEMPLATE +# ============================================================================= + +PYTHON_HOST_FUNC_TEMPLATE = """\ +import os +from pathlib import Path + +# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation) +import tvm.runtime as runtime + +_cpp_launcher = None +_cpp_launcher_lib = None +_cubin_generated = False + +# Pre-compute paths - cubin is stored alongside the launcher .so +# Use module basename to avoid conflicts when multiple kernels run concurrently +# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin" +# "kernel.py" (in cache) -> "kernel.cubin" +_module_dir = Path(os.path.dirname(__file__)) +_cubin_path = _module_dir / (Path(__file__).stem + ".cubin") +_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8') +_cubin_needs_generation = not _cubin_path.exists() + +def _generate_cubin_if_needed({cubin_gen_params}): + \"\"\"Generate cubin file on first call. + + All CuTeDSL imports are inside this function to avoid slow + module-level initialization when loading from cache. + \"\"\" + global _cubin_generated, _cubin_path + + # Lazy import CuTeDSL only when cubin generation is needed + from cuda.bindings.driver import CUstream + import cutlass + import cutlass.cute as cute + from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor + import tilelang.contrib.cutedsl as tl + # We rely on CuTeDSL's keep-cubin artifact rather than custom extraction. 
+ import tempfile + import shutil + + _DTYPE_MAP = {{ + "torch.float32": cutlass.Float32, + "torch.float16": cutlass.Float16, + "torch.bfloat16": cutlass.BFloat16, + "torch.float8_e4m3fnuz": cutlass.Float8E4M3FN, + "torch.float8_e4m3fn": cutlass.Float8E4M3FN, + "torch.float8_e5m2": cutlass.Float8E5M2, + "torch.float64": cutlass.Float64, + "torch.int64": cutlass.Int64, + "torch.int32": cutlass.Int32, + "torch.uint32": cutlass.Uint32, + "torch.bool": cutlass.Boolean, + "torch.int8": cutlass.Int8, + "torch.uint8": cutlass.Uint8, + "torch.int16": cutlass.Int16, + "torch.uint16": cutlass.Uint16, + "torch.uchar": cutlass.Uint8, + }} + +{cubin_gen_code} + + _cubin_generated = True + +def _load_cpp_launcher(): + \"\"\"Load C++ kernel launcher.\"\"\" + global _cpp_launcher, _cpp_launcher_lib + if _cpp_launcher is not None: + return _cpp_launcher + + lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}") + if not os.path.exists(lib_path): + raise FileNotFoundError(f"Launcher not found: {{lib_path}}") + + _cpp_launcher_lib = runtime.load_module(lib_path) + _cpp_launcher = _cpp_launcher_lib["launch_kernel"] + return _cpp_launcher + +def call({call_func_params}, stream): + \"\"\"Kernel dispatch function.\"\"\" + global _cubin_path_bytes, _cubin_needs_generation + + if _cubin_needs_generation: + _generate_cubin_if_needed({cubin_gen_call_args}) + _cubin_needs_generation = False + +{arg_prep_code} + + launcher = _load_cpp_launcher() + result = launcher({launcher_call_args}, stream, _cubin_path_bytes) + + if result != 0: + raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}") +""" + +# ============================================================================= +# WRAPPER CLASS +# ============================================================================= + + +class TLCuTeDSLSourceWrapper(TLCUDASourceWrapper): + """Wrapper class for TileLang CuTe DSL backend with C++ launcher. + + Generates optimized C++ launcher code that: + - Loads cubin via CUDA Driver API + - Passes TMA descriptors by value (host-side, no device copy) + - Launches kernels with minimal Python overhead + - Supports both single and multiple kernel scenarios + """ + + _TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "cutlass.Float32", + "float16": "cutlass.Float16", + "bfloat16": "cutlass.BFloat16", + "float8_e4m3": "cutlass.Float8E4M3", + "float8_e5m2": "cutlass.Float8E5M2", + "float64": "cutlass.Float64", + "int64": "cutlass.Int64", + "int32": "cutlass.Int32", + "uint32": "cutlass.Uint32", + "bool": "cutlass.Boolean", + "int8": "cutlass.Int8", + "uint8": "cutlass.Uint8", + "int16": "cutlass.Int16", + "uint16": "cutlass.Uint16", + "uchar": "cutlass.Uint8", + } + + # C++ launcher code must not depend on cutlass Python types. + # Use plain C/C++ types for expression rendering inside generated .cpp. 
+ _CXX_TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "float", + "float64": "double", + "int64": "int64_t", + "int32": "int32_t", + "uint32": "uint32_t", + "bool": "bool", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + } + + _CTYPES_MAP: ClassVar[dict[str, str]] = { + "buffer": "ctypes.c_uint64", + "cutlass.Float32": "ctypes.c_float", + "cutlass.Float16": "ctypes.c_uint16", + "cutlass.Float64": "ctypes.c_double", + "cutlass.Int64": "ctypes.c_int64", + "cutlass.Int32": "ctypes.c_int32", + "cutlass.Uint32": "ctypes.c_uint32", + "cutlass.Int8": "ctypes.c_int8", + "cutlass.Uint8": "ctypes.c_uint8", + "cutlass.Int16": "ctypes.c_int16", + "cutlass.Uint16": "ctypes.c_uint16", + "int": "ctypes.c_int32", + } + + _generated_host_func: str | None = None + _launcher_lib_name: str | None = None + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def host_func(self): + """Override parent's host_func to return generated Python code.""" + if self._generated_host_func is not None: + return self._generated_host_func + return super().host_func + + @host_func.setter + def host_func(self, value): + """Allow setting generated host function code.""" + self._generated_host_func = value + + # ========================================================================= + # Utility Methods + # ========================================================================= + + def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to Python string.""" + return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="//") + + def _cxx_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to C++ string for generated launcher code.""" + return pythonic_expr(expr, self._CXX_TYPE_MAP) + + @staticmethod + def _cxx_cast(ctype: str, expr_str: str) -> str: + return f"static_cast<{ctype}>({expr_str})" + + def _collect_function_args(self) -> tuple[list[dict], list[str]]: + """Collect all function arguments from primary function. 
+ + Returns: + Tuple of (function_args, buffer_args) + """ + function_args = [] + buffer_args = [] + + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append({"name": buffer.data.name, "type": "buffer"}) + buffer_args.append(buffer.data.name) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]}) + else: + raise ValueError(f"Parameter {param} not in buffer map") + + existing_names = {arg["name"] for arg in function_args} + for dyn_sym in self.get_dynamic_symbolic_set(self.prim_func): + dyn_sym_name, dyn_sym_dtype = dyn_sym if isinstance(dyn_sym, tuple) else (dyn_sym, "int32") + if dyn_sym_name in existing_names: + continue + existing_names.add(dyn_sym_name) + function_args.append({"name": dyn_sym_name, "type": self._TYPE_MAP.get(dyn_sym_dtype, "int")}) + + return function_args, buffer_args + + @staticmethod + def _extract_func_call_args( + declaration: str, + function_args: list[dict], + function_params: list, + desc_name_map: dict[str, str] | None = None, + desc_name_var_map: dict[str, tvm.tir.Var] | None = None, + ) -> list[tuple[str, str]]: + """Extract function call arguments from Python function declaration.""" + + def maybe_desc(name: str | tuple[str, str], param_names: list[str], i: int): + name_str = name if isinstance(name, str) else name[0] + param = param_names[i] + if not (param == name_str + "_desc" or param.startswith(name_str + "_desc_")): + return False + if desc_name_map is not None: + desc_name_map[param] = name_str + return True + + def extract_param_names_ast(decl: str) -> list[str] | None: + """Extract parameter names using AST parsing.""" + import ast + import warnings + + try: + # Build a syntactically valid function by adding a body + func_stub = decl.rstrip() + if not func_stub.endswith(":"): + func_stub += ":" + func_stub += "\n pass" + + # Parse and locate the FunctionDef + tree = ast.parse(func_stub) + func_def = None + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_def = node + break + + if func_def is None: + return None + + # Extract parameter names, skipping 'self' + param_names = [] + for arg in func_def.args.args: + if arg.arg != "self": + param_names.append(arg.arg) + + return param_names + except Exception as e: + warnings.warn(f"AST parsing failed for function declaration, falling back to split-based parsing: {e}", stacklevel=2) + return None + + def extract_param_names_split(decl: str) -> list[str]: + """Fallback: extract parameter names using naive split-based parsing.""" + paren_start = decl.find("(") + paren_end = decl.rfind(")") + if paren_start == -1 or paren_end == -1: + return [] + + params_str = decl[paren_start + 1 : paren_end].strip() + if not params_str: + return [] + + param_parts = params_str.split(",") + param_names = [] + for param in param_parts: + param = param.strip() + if not param or param == "self": + continue + if ":" in param: + param_name = param.split(":")[0].strip() + else: + param_name = param.strip() + param_names.append(param_name) + + return param_names + + # Try AST-based extraction first, fallback to split-based + param_names = extract_param_names_ast(declaration) + if param_names is None: + param_names = extract_param_names_split(declaration) + + call_args = [] + for i, param_name in enumerate(param_names): + for arg in function_args: + if arg["name"] == param_name: + call_args.append((param_name, arg["type"])) + elif 
maybe_desc(arg["name"], param_names, i): + call_args.append((param_name, "None")) + if desc_name_var_map is not None and function_params is not None: + assert len(call_args) <= len(function_params) + desc_name_var_map[param_name] = function_params[len(call_args) - 1] + return call_args + + @staticmethod + def _filter_non_descriptor_args( + call_args: list[tuple[str, str]], desc_names: list[str], tma_tensors: list[str] + ) -> list[tuple[str, str]]: + """Filter out descriptor arguments.""" + filtered = [] + for arg_name, arg_type in call_args: + if "desc" in arg_name and arg_name in desc_names: + continue + if arg_name in tma_tensors: + continue + filtered.append((arg_name, arg_type)) + return filtered + + # ========================================================================= + # TMA Descriptor Code Generation + # ========================================================================= + + def _generate_tma_desc_init(self, desc_name: str, desc_idx: int, tensor_name: str, info: dict) -> str: + """Generate single TMA descriptor initialization code.""" + if info.get("is_img2col", False): + rank = info["tensor_rank"] + return CPP_TMA_IM2COL_DESC_INIT_TEMPLATE.format( + desc_idx=desc_idx, + desc_name=desc_name, + tensor_name=tensor_name, + rank=rank, + stride_rank=rank - 1, + rank_minus_two=rank - 2, + global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]), + global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]), + elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]), + lower_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["lower_corner"]), + upper_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["upper_corner"]), + # Match NVRTC wrapper naming: channelsPerPixel then pixelsPerColumn + channels_per_pixel=info["smem_box_channel"], + pixels_per_column=info["smem_box_pixel"], + dtype=info["dtype"], + interleave=info["interleave"], + swizzle=info["swizzle"], + l2_promotion=info["l2Promotion"], + oob_fill=info["oobFill"], + ) + + return CPP_TMA_DESC_INIT_TEMPLATE.format( + desc_idx=desc_idx, + desc_name=desc_name, + tensor_name=tensor_name, + rank=info["tensor_rank"], + global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]), + stride_rank=info["tensor_rank"] - 1, + global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]), + box_dim_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["box_dim"]), + elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]), + dtype=info["dtype"], + interleave=info["interleave"], + swizzle=info["swizzle"], + l2_promotion=info["l2Promotion"], + oob_fill=info["oobFill"], + ) + + def _generate_tma_init_func( + self, + desc_names: list[str], + tensor_args: list[str], + tensor_arg_map: dict[str, tuple[str, int]], + scalar_args: list[dict[str, str]], + ) -> str: + """Generate TMA init function code (creates descriptors in caller-provided host array). + + TMA descriptors are stored in stack-local tma_descs[] array in launch_kernel. + cuLaunchKernel automatically handles __grid_constant__ params. 
+ """ + if not desc_names: + return "" + + func_args_parts = [f"uint64_t {arg}_ptr" for arg in tensor_args] + for arg in scalar_args: + if arg["type"] in ["int", "cutlass.Int32"]: + func_args_parts.append(f"int32_t {arg['name']}") + elif arg["type"] in ["float", "cutlass.Float32"]: + func_args_parts.append(f"float {arg['name']}") + else: + # Default to int32_t for scalars used in shape/stride math + func_args_parts.append(f"int32_t {arg['name']}") + func_args = ", ".join(func_args_parts) + num_descs = len(desc_names) + + desc_inits = [] + for idx, desc_name in enumerate(desc_names): + info = self.tma_desc_info[desc_name] + tensor_name, _ = tensor_arg_map[desc_name] + desc_inits.append(self._generate_tma_desc_init(desc_name, idx, tensor_name, info)) + + return CPP_TMA_INIT_FUNC_TEMPLATE.format( + func_args=func_args, + num_descs=num_descs, + desc_init_code="\n".join(desc_inits), + ) + + def _generate_tma_launch_init( + self, desc_names: list[str], tma_tensors: list[str], scalar_args: list[dict[str, str]], num_tma_descs: int + ) -> str: + """Generate TMA initialization code for launch function (host memory mode). + + TMA descriptors stay on host. cuLaunchKernel copies them to param space + when kernel uses __grid_constant__ CUtensorMap parameter. + """ + if not desc_names: + return "" + + # Generate tma_init call args (no device_ptr needed) + call_args_parts = [f"{arg}_ptr" for arg in tma_tensors] + [arg["name"] for arg in scalar_args] + tma_tensor_args = ", ".join(call_args_parts) + + return CPP_TMA_LAUNCH_INIT_TEMPLATE.format( + num_tma_descs=num_tma_descs, + tma_tensor_args=tma_tensor_args, + ) + + # ========================================================================= + # Kernel Code Generation + # ========================================================================= + + def _generate_kernel_init(self, kernel_idx: int, kernel_name: str, smem_size: int) -> str: + """Generate kernel initialization code.""" + return CPP_KERNEL_INIT_TEMPLATE.format( + kernel_idx=kernel_idx, + kernel_name=kernel_name, + smem_size=smem_size, + ) + + def _generate_kernel_launch(self, kernel_meta: dict, kernel_idx: int, all_desc_names: list[str]) -> str: + """Generate single kernel launch code. 
+ + For __grid_constant__ CUtensorMap params: + - Pass CUtensorMap* directly (not CUtensorMap**) + - cuLaunchKernel copies 128 bytes to kernel param space + """ + call_args = kernel_meta["call_args"] + desc_names = kernel_meta["desc_names"] + function_info = kernel_meta["function_info"] + + # Build kernel args + kernel_args = [] + for arg_name, arg_type in call_args: + if "desc" in arg_name and arg_name in desc_names: + # For __grid_constant__ CUtensorMap: pass host pointer directly + # cuLaunchKernel will copy 128-byte CUtensorMap to param space + desc_idx = all_desc_names.index(arg_name) + kernel_args.append(f"&tma_descs[{desc_idx}]") + elif arg_type == "buffer": + kernel_args.append(f"&{arg_name}_ptr") + else: + kernel_args.append(f"&{arg_name}") + + grid = function_info["grid_info"] + block = function_info["block_info"] + smem_size = function_info["dynamic_smem_buf"] or 0 + + return CPP_KERNEL_LAUNCH_TEMPLATE.format( + kernel_idx=kernel_idx, + kernel_name=kernel_meta["function_name"], + kernel_args=", ".join(kernel_args), + grid_x=self._cxx_expr(grid[0]), + grid_y=self._cxx_expr(grid[1]), + grid_z=self._cxx_expr(grid[2]), + block_x=self._cxx_expr(block[0]), + block_y=self._cxx_expr(block[1]), + block_z=self._cxx_expr(block[2]), + smem_size=smem_size, + ) + + # ========================================================================= + # C++ Launcher Generation + # ========================================================================= + + def _generate_cpp_launcher( + self, + kernel_metadata_list: list[dict], + function_args: list[dict], + all_tma_tensors: list[str], + all_desc_names: list[str], + tensor_arg_map: dict[str, tuple[str, int]], + ) -> str: + """Generate complete C++ launcher code using templates. + + TMA descriptors are stored on HOST memory in stack-local tma_descs[] array. + cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space + when kernel uses __grid_constant__ parameter. 
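+
+        The generated translation unit exposes two entry points, launch_kernel and
+        cleanup_module, registered via TVM_FFI_DLL_EXPORT_TYPED_FUNC; libgen.py
+        compiles it with nvcc into the launcher shared library that kernel.py loads.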
+ """ + num_kernels = len(kernel_metadata_list) + num_tma_descs = max(len(all_desc_names), 1) # At least 1 to avoid zero-size array + + # Generate kernel inits + kernel_inits = "\n".join( + self._generate_kernel_init(idx, km["function_name"], km["function_info"]["dynamic_smem_buf"] or 0) + for idx, km in enumerate(kernel_metadata_list) + ) + + # Generate TMA init function + scalar_args = [arg for arg in function_args if arg["type"] != "buffer"] + tma_init_func = self._generate_tma_init_func(all_desc_names, all_tma_tensors, tensor_arg_map, scalar_args) + + # Generate launch function signature and get_ptr code + func_sig_parts = [] + get_ptr_code = "" + for arg in function_args: + if arg["type"] == "buffer": + func_sig_parts.append(f"tvm::ffi::TensorView {arg['name']}") + get_ptr_code += f" uint64_t {arg['name']}_ptr = reinterpret_cast({arg['name']}.data_ptr());\n" + elif arg["type"] in ["int", "cutlass.Int32"]: + func_sig_parts.append(f"int32_t {arg['name']}") + elif arg["type"] in ["float", "cutlass.Float32"]: + func_sig_parts.append(f"float {arg['name']}") + else: + func_sig_parts.append(f"int32_t {arg['name']}") + + # Generate TMA init in launch + tma_init_in_launch = self._generate_tma_launch_init(all_desc_names, all_tma_tensors, scalar_args, num_tma_descs) + + # Generate kernel launches + kernel_launches = "\n".join(self._generate_kernel_launch(km, idx, all_desc_names) for idx, km in enumerate(kernel_metadata_list)) + + return CPP_LAUNCHER_TEMPLATE.format( + num_kernels=num_kernels, + num_tma_descs=num_tma_descs, + kernel_inits=kernel_inits, + tma_init_func=tma_init_func, + launch_func_sig=", ".join(func_sig_parts), + get_ptr_code=get_ptr_code, + tma_init_in_launch=tma_init_in_launch, + kernel_launches=kernel_launches, + ) + + # ========================================================================= + # Python Wrapper Generation + # ========================================================================= + + def _generate_cubin_gen_code( + self, + kernel_metadata_list: list[dict], + buffer_args: list[str], + all_desc_names: list[str], + lib_code: str = "", + ) -> str: + """Generate cubin generation code for Python wrapper using templates. + + Args: + lib_code: The CuTeDSL kernel definitions (@cute.kernel decorated functions). + This will be embedded inside _generate_cubin_if_needed to enable + lazy loading of cutlass/cute modules. 
+ """ + # Build unified wrapper parameters + wrapper_params_union = [] + for kernel_meta in kernel_metadata_list: + for arg_name, _ in kernel_meta["call_args"]: + if arg_name not in wrapper_params_union: + wrapper_params_union.append(arg_name) + + # Build inner args for cute.compile + inner_args = [] + fake_inner_args = [] + for arg_name in wrapper_params_union: + if arg_name in buffer_args: + inner_args.append(f"{arg_name}_") + fake_inner_args.append(f"__fake_{arg_name}__") + elif arg_name in all_desc_names: + continue + else: + inner_args.append(arg_name) + fake_inner_args.append(arg_name) + if all_desc_names: + inner_args.append("__fake_tensor__") + fake_inner_args.append("__fake_tensor__") + fake_inner_args.append("__fake_stream__") + + # Generate TMA init code + tma_init_code = "" + if all_desc_names: + tma_init_lines = [" # Create dummy TMA atoms for compilation"] + tma_init_lines.extend(CUBIN_TMA_ATOM_INIT_TEMPLATE.format(desc_name=desc_name) for desc_name in all_desc_names) + tma_init_code = "\n".join(tma_init_lines) + "\n" + + # Generate kernel launch calls + kernel_launches = "\n".join( + CUBIN_KERNEL_LAUNCH_TEMPLATE.format( + function_name=km["function_name"], + call_args=", ".join(arg[0] if arg[0] not in buffer_args else f"{arg[0]}_" for arg in km["call_args"]), + grid_x=self._pythonic_expr(km["function_info"]["grid_info"][0]), + grid_y=self._pythonic_expr(km["function_info"]["grid_info"][1]), + grid_z=self._pythonic_expr(km["function_info"]["grid_info"][2]), + block_x=self._pythonic_expr(km["function_info"]["block_info"][0]), + block_y=self._pythonic_expr(km["function_info"]["block_info"][1]), + block_z=self._pythonic_expr(km["function_info"]["block_info"][2]), + smem_size=km["function_info"]["dynamic_smem_buf"] or 0, + ) + for km in kernel_metadata_list + ) + + # Generate fake tensor creation code + # IMPORTANT: Generate fake tensors based on the *union* of parameters actually + # passed to cute.compile (wrapper_params_union). + # + # In multi-kernel cases, a tensor may appear both as a TMA descriptor + # (e.g. Output_partial_desc) for one kernel and as a plain tensor argument + # (e.g. Output_partial_) for another kernel. Skipping fake tensor creation + # just because a matching "{arg}_desc" exists is a correctness bug and + # results in undefined names like "__fake_Output_partial__". 
+ fake_tensor_code = "\n".join( + CUBIN_FAKE_TENSOR_TEMPLATE.format(arg_name=arg_name) for arg_name in wrapper_params_union if arg_name in buffer_args + ) + if fake_tensor_code: + fake_tensor_code += "\n" + + # Generate fake TMA tensor code + fake_tma_tensor_code = "" + if all_desc_names: + fake_tma_tensor_code = ( + " __fake_tensor__ = make_fake_compact_tensor(cutlass.Int32, (32, 32), stride_order=(1, 0), assumed_align=16)\n" + ) + + # Indent lib_code to be inside the function + indented_lib_code = "\n".join(" " + line if line.strip() else line for line in lib_code.split("\n")) if lib_code else "" + + return CUBIN_GEN_CODE_TEMPLATE.format( + lib_code=indented_lib_code, + wrapper_args=", ".join(inner_args + ["stream: CUstream"]), + tma_init_code=tma_init_code, + kernel_launches=kernel_launches, + fake_tensor_code=fake_tensor_code, + fake_tma_tensor_code=fake_tma_tensor_code, + compile_args=", ".join(fake_inner_args), + primary_name=kernel_metadata_list[0]["function_name"], + ) + + def _generate_python_wrapper( + self, + function_args: list[dict], + cubin_gen_code: str, + cubin_gen_params: str, + ) -> str: + """Generate Python wrapper code.""" + # Build function parameters + call_func_params = ", ".join(arg["name"] for arg in function_args) + launcher_call_args = ", ".join(arg["name"] for arg in function_args) + + return PYTHON_HOST_FUNC_TEMPLATE.format( + cubin_gen_params=cubin_gen_params, + cubin_gen_code=cubin_gen_code, + launcher_lib_name=self._launcher_lib_name, + call_func_params=call_func_params, + cubin_gen_call_args=cubin_gen_params, + arg_prep_code="", + launcher_call_args=launcher_call_args, + ) + + # ========================================================================= + # TMA Descriptor Processing + # ========================================================================= + + def _process_tma_descriptors(self, desc_names: list[str]) -> tuple[list[str], dict[str, tuple[str, int]]]: + """Process TMA descriptors and return tensor args and mapping. + + Returns: + Tuple of (tensor_args, tensor_arg_map) + """ + if not hasattr(self, "tma_desc_info") or not desc_names: + return [], {} + + tensor_args = [] + tensor_arg_map = {} + + for desc_name in desc_names: + info = self.tma_desc_info[desc_name] + # Extract the base buffer variable name (must be a Var, not arbitrary expression) + global_addr = info["globalAddress"] + if not isinstance(global_addr, tvm.tir.Var): + raise ValueError(f"TMA globalAddress must be a buffer Var, got {type(global_addr)}: {global_addr}") + tensor_name = global_addr.name + + if tensor_name not in tensor_args: + tensor_args.append(tensor_name) + tensor_arg_map[desc_name] = (tensor_name, len(tensor_args) - 1) + else: + tensor_arg_map[desc_name] = (tensor_name, tensor_args.index(tensor_name)) + + return tensor_args, tensor_arg_map + + def generate_tma_descriptor_args( + self, + desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var], + tma_desc_code_map: dict[str, str], + ) -> list[str]: + """Generate TMA descriptor information for C++ code generation. + + Returns: + List of descriptor variable names in the order they were processed. 
+ """ + if self.tma_descriptor_args is None: + return [] + + if not hasattr(self, "tma_desc_info"): + self.tma_desc_info = {} + + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) + + desc_names_ordered = [] + + for params in parsed_params: + handle_name = params.handle_name + + if handle_name in tma_desc_code_map: + continue + + desc_var = desc_name_var_map[handle_name] + args = self.tma_descriptor_args[desc_var] + _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:] + tensor_rank = int(tensor_rank) + + global_dim = remaining_args[:tensor_rank] + global_stride = remaining_args[tensor_rank : 2 * tensor_rank] + + if not params.is_img2col: + box_dim = remaining_args[2 * tensor_rank : 3 * tensor_rank] + element_strides = remaining_args[3 * tensor_rank : 4 * tensor_rank] + + self.tma_desc_info[handle_name] = { + "desc_var": desc_var, + "is_img2col": False, + "dtype": params.dtype, + "tensor_rank": params.tensor_rank, + "globalAddress": params.global_address, + "global_dim": global_dim, + "global_stride": global_stride, + "box_dim": box_dim, + "element_strides": element_strides, + "interleave": params.interleave, + "swizzle": params.swizzle, + "l2Promotion": params.l2_promotion, + "oobFill": params.oob_fill, + } + else: + element_strides = remaining_args[2 * tensor_rank : 3 * tensor_rank] + + self.tma_desc_info[handle_name] = { + "desc_var": desc_var, + "is_img2col": True, + "dtype": params.dtype, + "tensor_rank": params.tensor_rank, + "globalAddress": params.global_address, + "global_dim": global_dim, + "global_stride": global_stride, + "element_strides": element_strides, + "lower_corner": params.lower_corner, + "upper_corner": params.upper_corner, + "smem_box_channel": params.smem_box_channel, + "smem_box_pixel": params.smem_box_pixel, + "interleave": params.interleave, + "swizzle": params.swizzle, + "l2Promotion": params.l2_promotion, + "oobFill": params.oob_fill, + } + + tma_desc_code_map[handle_name] = "" + desc_names_ordered.append(handle_name) + + return desc_names_ordered + + # ========================================================================= + # Main Entry Points + # ========================================================================= + + def create_dispatch_func(self, code, function_informations): + """Create dispatch function - always use C++ launcher.""" + return self.create_dispatch_func_cpp_launcher(code, function_informations) + + def create_dispatch_func_cpp_launcher(self, code, function_informations): + """Create dispatch function using C++ launcher.""" + function_args, buffer_args = self._collect_function_args() + + # Process each kernel and collect metadata + kernel_metadata = [] + all_desc_names_union = [] + all_tma_tensors_union = [] + + for function_name, function_info in function_informations.items(): + declaration = extract_python_func_declaration(code, function_name) + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} + call_args = self._extract_func_call_args( + declaration, + function_args, + function_info["function_params"], + desc_name_map, + desc_name_var_map, + ) + + tma_desc_code_map = {} + desc_names = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map) + + tma_tensor_args, _ = self._process_tma_descriptors(desc_names) + + kernel_metadata.append( + { + "function_name": function_name, + "function_info": function_info, + "call_args": call_args, + "desc_names": desc_names, + "tma_tensor_args": 
tma_tensor_args, + "desc_name_map": desc_name_map, + } + ) + + for desc in desc_names: + if desc not in all_desc_names_union: + all_desc_names_union.append(desc) + for t in tma_tensor_args: + if t not in all_tma_tensors_union: + all_tma_tensors_union.append(t) + + # Process all TMA descriptors + all_tma_tensors, tensor_arg_map = self._process_tma_descriptors(all_desc_names_union) + + # Generate C++ launcher + launcher_cpp_code = self._generate_cpp_launcher( + kernel_metadata, function_args, all_tma_tensors, all_desc_names_union, tensor_arg_map + ) + + self.launcher_cpp_code = launcher_cpp_code + # Use a deterministic name so that: + # 1) the generated kernel.py can always locate the launcher in the same directory + # 2) KernelCache can store it under a stable filename + self._launcher_lib_name = "launcher_lib.so" + self.launcher_lib_name = self._launcher_lib_name + + # Generate cubin generation code (includes lib_code with @cute.kernel definitions) + cubin_gen_code = self._generate_cubin_gen_code( + kernel_metadata, buffer_args, all_desc_names_union, lib_code=getattr(self, "lib_code", "") + ) + + # Generate Python wrapper + buffer_names = [arg["name"] for arg in function_args if arg["type"] == "buffer"] + # Cubin generation may reference scalar args (e.g., dynamic symbols like m/n/k) + # inside `kernel_wrapper` and `cute.compile(...)`. They must be visible in + # `_generate_cubin_if_needed(...)` scope, so include them in its signature. + scalar_names = [arg["name"] for arg in function_args if arg["type"] != "buffer"] + cubin_gen_params = ", ".join(buffer_names + scalar_names) + + python_wrapper = self._generate_python_wrapper(function_args, cubin_gen_code, cubin_gen_params) + + return python_wrapper + + def get_launcher_cpp_code(self) -> str: + """Get the generated C++ launcher code.""" + return getattr(self, "launcher_cpp_code", "") + + def update_lib_code(self, code: str): + """Update the library code with the given code string.""" + self.lib_code = code + + function_informations = {} + for function_name in self.function_names: + if (function_name not in self.block_info) or (function_name not in self.grid_info): + continue + + assert function_name in self.device_mod, f"Function {function_name} not found in device module" + device_func = self.device_mod[function_name] + kernel_params_cnt = len(device_func.params) + function_params: list[str] = None + + def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): + nonlocal function_params + if isinstance(node, tvm.tir.Call): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + return + args = node.args + if not args or args[0] != fn: + return + if len(args) < 1 + param_cnt: + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] + + post_order_visit(self.host_func.body, visitor) + assert function_params is not None, "function_params should not be None" + + function_informations[function_name] = { + "function_name": function_name, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + "function_params": function_params, + } + + self.host_func = self.create_dispatch_func(code, function_informations) + return self.lib_code diff --git a/tilelang/original/tilelang/jit/adapter/cython/__init__.py b/tilelang/original/tilelang/jit/adapter/cython/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..1cc945cd564858b6ce40244e7ae7c9d545ec180d --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cython/__init__.py @@ -0,0 +1 @@ +from .adapter import CythonKernelAdapter # noqa: F401 diff --git a/tilelang/original/tilelang/jit/adapter/cython/adapter.py b/tilelang/original/tilelang/jit/adapter/cython/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..c456e4dbaa38f6bcd341198f27ca56364ae1fb68 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cython/adapter.py @@ -0,0 +1,387 @@ +"""The profiler and convert to torch utils""" + +from __future__ import annotations +import ctypes +import logging +import torch + +from typing import Callable, Any +from tilelang import tvm as tvm +from tvm.target import Target +from tilelang.engine.param import KernelParam +from tvm import tir +from tvm.relax import TensorType + +from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.wrapper import TLWrapper +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target +from tilelang.utils.target import determine_target +from tilelang.utils.language import retrieve_func_from_module +from tilelang.utils.tensor import map_torch_type + +logger = logging.getLogger(__name__) + +try: + from tilelang_cython_wrapper import CythonKernelWrapper +except ImportError: + raise + + +def is_symbolic_expr(expr) -> bool: + """Check if the expression is a symbolic expression. + A symbolic expression can be a simple tvm.Var, or an tvm.PrimExpr containing tvm.Var. + """ + return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr) + + +class CythonKernelAdapter(BaseKernelAdapter): + """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython. + + This adapter handles: + 1. Converting TIR functions to compiled CUDA libraries + 2. Managing dynamic shapes in tensor operations + 3. 
Wrapping C++ kernels for Python/PyTorch usage + """ + + # Class attributes to store compiled kernel information + target: str | Target = "cuda" + ir_module: tvm.IRModule | None = None + # The global source code of the kernel -> global means the source code of the kernel + # that is not wrapped by the wrapper code + host_kernel_source: str | None = None + device_kernel_source: str | None = None + lib: ctypes.CDLL | None = None # Compiled library handle + # Maps symbolic variables to their corresponding buffer and shape indices + dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None + # Maps pointer arguments to their corresponding (buffer_index, shape_dimension) + ptr_map: dict[int, str] | None = None + # Maps buffer variables to their corresponding dtypes + buffer_dtype_map: dict[tir.Var, tuple[int, torch.dtype]] | None = None + # Maps buffer variables to their corresponding static shapes and strides, + # e.g., { + # "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16) + # } + static_shape_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None + static_strides_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None + # Contains contiguous buffers + static_contiguous_list: list[tir.Var] | None = None + # Maps buffer variables to their corresponding devices + buffer_device_map: dict[tir.Var, tuple[int, torch.device]] | None = None + # Pass configs for the compiler + pass_configs: dict[str, Any] | None = None + + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + """Initialize the adapter with the given TIR function or module. 
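+
+         Illustrative construction (a sketch only; the argument values are
+         hypothetical and are normally supplied by tilelang's JIT pipeline,
+         not written by hand):
+
+             adapter = CythonKernelAdapter(
+                 params=kernel_params, result_idx=[2], target="cuda",
+                 func_or_mod=prim_func, host_mod=host_mod, device_mod=device_mod,
+                 device_kernel_source=cuda_source,
+             )
+             torch_fn = adapter._convert_torch_func()
+             c = torch_fn(a, b)  # output tensor(s) are allocated automatically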
+ + Args: + params: List of tensor types for inputs/outputs + result_idx: Indices of output tensors + target: Target platform (e.g., 'cuda') + func_or_mod: TIR function or module to be compiled + verbose: Enable verbose logging + """ + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + self.ir_module = func_or_mod + + self.target = Target.canon_target(determine_target(target)) + + self.dynamic_symbolic_map = self._process_dynamic_symbolic() + self.buffer_dtype_map = self._process_buffer_dtype() + self.ptr_map = self._process_ptr_map() + self.buffer_device_map = self._process_buffer_device() + + static_buffer_infos = self._process_static_buffer_infos() + self.static_shape_map = static_buffer_infos[0] + self.static_strides_map = static_buffer_infos[1] + self.static_contiguous_list = static_buffer_infos[2] + + self.verbose = verbose + self.wrapper = TLWrapper(self.target) + self.lib_generator = LibraryGenerator(self.target, verbose=verbose) + self.lib_generator.assign_pass_configs(pass_configs) + self.lib_generator.assign_compile_flags(compile_flags) + + self.wrapper.assign_optimized_module(self.ir_module) + self.wrapper.assign_pass_configs(pass_configs) + self.wrapper.assign_host_module(host_mod) + self.wrapper.assign_device_module(device_mod) + self.host_kernel_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True)) + + self.lib_generator.update_lib_code(self.host_kernel_source) + self.lib_generator.compile_lib() + self.lib = self.lib_generator.load_lib() + + self.lib.get_last_error.restype = ctypes.c_char_p + result = self.lib.init() + if result != 0: + error_msg = self.lib.get_last_error().decode("utf-8") + error_msg += f"\n{self.lib_code}" + raise RuntimeError(f"Initialization failed: {error_msg}") + + self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib) + self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map) + self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map) + self.cython_wrapper.set_static_shape_map(self.static_shape_map) + self.cython_wrapper.set_static_strides_map(self.static_strides_map) + self.cython_wrapper.set_static_contiguous_list(self.static_contiguous_list) + self.cython_wrapper.set_buffer_device_map(self.buffer_device_map) + self.cython_wrapper.set_ptr_map(self.ptr_map) + self._post_init() + + @classmethod + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + adapter.pass_configs = pass_configs + + if isinstance(func_or_mod, tir.PrimFunc): + adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + target = determine_target(target, return_object=True) + adapter.target = Target.canon_target(determine_target(target)) + + adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic() + adapter.buffer_dtype_map = 
adapter._process_buffer_dtype() + adapter.ptr_map = adapter._process_ptr_map() + adapter.buffer_device_map = adapter._process_buffer_device() + + static_buffer_infos = adapter._process_static_buffer_infos() + adapter.static_shape_map = static_buffer_infos[0] + adapter.static_strides_map = static_buffer_infos[1] + adapter.static_contiguous_list = static_buffer_infos[2] + + adapter.verbose = verbose + adapter.lib_generator = LibraryGenerator(adapter.target, verbose=verbose) + adapter.lib_generator.assign_pass_configs(pass_configs) + adapter.lib_generator.assign_compile_flags(compile_flags) + adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path) + + adapter.lib.get_last_error.restype = ctypes.c_char_p + result = adapter.lib.init() + if result != 0: + error_msg = adapter.lib.get_last_error().decode("utf-8") + raise RuntimeError(f"Initialization failed: {error_msg}") + + adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib) + adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map) + adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map) + adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map) + adapter.cython_wrapper.set_static_strides_map(adapter.static_strides_map) + adapter.cython_wrapper.set_static_contiguous_list(adapter.static_contiguous_list) + adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map) + adapter.cython_wrapper.set_ptr_map(adapter.ptr_map) + + adapter._post_init() + return adapter + + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: + """Extract information about dynamic shapes from the TIR function. + + Maps symbolic variables to their corresponding (id, buffer_index, dimension) + for runtime shape resolution. + id represents shape or stride, 0 represents shape, 1 represents stride + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map = {} + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): + dynamic_symbolic_map[shape] = (0, i, j) + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, stride in enumerate(buffer.strides): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): + dynamic_symbolic_map[stride] = (1, i, j) + return dynamic_symbolic_map + + def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]: + """Extract information about buffer dtypes from the TIR function. + + Maps buffer variables to their corresponding dtypes. + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + buffer_dtype_map = {} + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + name, dtype = buffer.name, buffer.dtype + buffer_dtype_map[name] = (i, map_torch_type(dtype)) + return buffer_dtype_map + + def _process_ptr_map(self) -> dict[int, str]: + """Extract information about pointer arguments from the TIR function. + + Maps pointer arguments to their corresponding (buffer_index, shape_dimension) + for runtime shape resolution. 
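+
+         Illustrative result (hypothetical names): for a PrimFunc whose third
+         parameter is a raw "handle" (e.g. a workspace pointer named
+         "workspace"), the returned map would be {2: "workspace"}, i.e.
+         parameter index -> parameter name for every handle-typed parameter.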
+ """ + func = self.prim_func + params = func.params + ptr_map = {} + for i, param in enumerate(params): + if param.dtype == "handle": + ptr_map[i] = param.name + return ptr_map + + def _process_static_buffer_infos( + self, + ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]: + """Extract information about static shapes from the TIR function. + + Maps buffer variables to their corresponding static shapes. + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + static_shape_map = {} + static_strides_map = {} + static_contiguous_list = list() + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + static_shape, static_strides = [], [] + for j, s in enumerate(buffer.shape): + if isinstance(s, tir.IntImm): + static_shape.append((j, s.value)) + elif is_symbolic_expr(s): + static_shape.append((j, -1)) # -1 for symbolic + else: + raise ValueError(f"Unsupported shape type: {type(s)}") + for j, s in enumerate(buffer.strides): + if isinstance(s, tir.IntImm): + static_strides.append((j, s.value)) + is_contiguous, prod = True, 1 + for dim, stride in reversed(list(zip(buffer.shape, buffer.strides))): + is_contiguous &= bool(stride == prod) + prod *= dim + static_shape_map[buffer.name] = (i, static_shape) + static_strides_map[buffer.name] = (i, static_strides) + if is_contiguous: + static_contiguous_list.append((i, buffer.name)) + return static_shape_map, static_strides_map, static_contiguous_list + + def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]: + """Extract information about buffer devices from the TIR function. + + Maps buffer variables to their corresponding devices. + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + buffer_device_map = {} + device = None + if is_cuda_target(self.target) or is_hip_target(self.target): + device = torch.device("cuda") + elif is_cpu_target(self.target): + device = torch.device("cpu") + elif is_metal_target(self.target): + device = torch.device("mps") + else: + raise ValueError(f"Unsupported target: {self.target}") + + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + name = buffer.name + buffer_device_map[name] = (i, device) + return buffer_device_map + + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): + """Low-level function to call the compiled CUDA kernel. + + Converts PyTorch tensor pointers to C void pointers for ctypes interface. + """ + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] + ctypes_args.append(ctypes.c_void_p(stream)) + self.lib.call(*ctypes_args) + + def _convert_torch_func(self) -> Callable: + """Returns a PyTorch-compatible function wrapper for the kernel.""" + + def lambda_forward(*args, stream: int = -1, skip_tensor_validation: bool = False): + """ + Args: + args: List of input tensors + stream: CUDA stream ID, default to -1, will use the current stream if not specified + skip_tensor_validation: Whether to skip tensor attributes validation which + includes shape, dtype, device, etc. 
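+
+                 Illustrative calls, where kernel is the callable returned by
+                 _convert_torch_func() (tensor names are hypothetical):
+
+                     out = kernel(a, b)  # uses the current CUDA stream
+                     out = kernel(a, b, stream=torch.cuda.current_stream().cuda_stream)
+                     out = kernel(a, b, skip_tensor_validation=True)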
+ """ + return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation) + + return lambda_forward + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) + + @property + def srcpath(self): + """Returns the source path of the compiled library.""" + return self.lib_generator.srcpath + + @property + def libpath(self): + """Returns the path to the compiled library.""" + return self.lib_generator.libpath + + @property + def lib_code(self): + """Returns the code of the compiled library.""" + return self.lib_generator.lib_code + + @property + def is_dynamic(self): + """Indicates whether the kernel handles dynamic shapes.""" + return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0 + + def get_kernel_source(self, kernel_only: bool = False): + """Returns the source code of the compiled kernel.""" + if kernel_only: + return self.device_kernel_source + else: + # Wrapper only has host kernel source + assert self.host_kernel_source is not None, "Wrapped source is not available" + return self.host_kernel_source diff --git a/tilelang/original/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/original/tilelang/jit/adapter/cython/cython_wrapper.pyx new file mode 100644 index 0000000000000000000000000000000000000000..dc462c627f3fe8c3b512df7e2e119cd64a057685 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/cython/cython_wrapper.pyx @@ -0,0 +1,289 @@ +# cython: language_level=3 + +import torch +cimport cython +import ctypes +from libc.stdint cimport int64_t, uintptr_t +from libc.stdlib cimport malloc, free +from tvm import tir +from tilelang.utils.tensor import map_torch_type + +cdef class CythonKernelWrapper: + # Class attributes to store kernel configuration and library reference + cdef: + object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices + object buffer_device_map # Maps buffer variables to their corresponding devices + object buffer_dtype_map # Maps buffer variables to their corresponding dtypes + object static_shape_map # Maps buffer variables to their corresponding static shapes + object static_strides_map # Maps buffer variables to their corresponding static strides + object static_contiguous_list # A list contains contiguous buffers + object ptr_map # Maps pointer arguments to their corresponding buffer indices + list result_idx # Indices of output tensors in the params list + list params # List of parameter specifications (includes both inputs and outputs) + object lib # Reference to the compiled library containing the kernel + # Add new cache attributes + list param_dtypes # Cache for parameter dtypes + list param_shapes # Cache for parameter shapes as native Python lists + object get_current_device + + def __cinit__(self, result_idx, params, lib): + # Initialize wrapper with kernel configuration + self.result_idx = result_idx + self.params = params + self.lib = lib + # Convert TVM types to native Python types during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + self.param_dtypes = [param.torch_dtype() for param in params] + # Convert TVM shape arrays to native Python lists + self.param_shapes = [] + self.get_current_device = torch.cuda.current_device + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + 
native_shape.append(dim) # Keep tir.Var for dynamic dimensions + else: + native_shape.append(dim) + self.param_shapes.append(native_shape) + + def set_dynamic_symbolic_map(self, dynamic_symbolic_map): + self.dynamic_symbolic_map = dynamic_symbolic_map + return self + + def set_buffer_dtype_map(self, buffer_dtype_map): + self.buffer_dtype_map = buffer_dtype_map + return self + + def set_static_shape_map(self, static_shape_map): + self.static_shape_map = static_shape_map + return self + + def set_static_strides_map(self, static_strides_map): + self.static_strides_map = static_strides_map + return self + + def set_static_contiguous_list(self, static_contiguous_list): + self.static_contiguous_list = static_contiguous_list + return self + + def set_ptr_map(self, ptr_map): + self.ptr_map = ptr_map + return self + + def set_buffer_device_map(self, buffer_device_map): + self.buffer_device_map = buffer_device_map + return self + + cpdef void _check_buffer_device(self, list tensor_list): + for param, (buffer_idx, device) in self.buffer_device_map.items(): + tensor = tensor_list[buffer_idx] + if isinstance(tensor, torch.Tensor): + tensor_device = tensor.device + device_type_match = device.type == tensor_device.type + device_index_match = ( + tensor_device.index is None or + device.index is None or + tensor_device.index == device.index + ) + if not (device_type_match and device_index_match): + raise ValueError( + f"Buffer device mismatch for parameter {param}: " + f"expected {device}, got {tensor_device}" + ) + + cpdef void _check_buffer_dtype(self, list tensor_list): + for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items(): + tensor = tensor_list[buffer_idx] + if isinstance(tensor, torch.Tensor) and tensor.dtype != torch_dtype: + raise ValueError( + f"Buffer dtype mismatch for parameter {param}: " + f"expected {torch_dtype}, got {tensor.dtype}" + ) + + cpdef void _check_static_shape(self, list tensor_list): + for param, (buffer_idx, shape_list) in self.static_shape_map.items(): + tensor = tensor_list[buffer_idx] + if not isinstance(tensor, torch.Tensor): + # otherwise, maybe torch.data_ptr() for T.ptr inputs + continue + + # Check ndim + if tensor.dim() != len(shape_list): + raise ValueError( + f"Static shape mismatch for parameter {param}: " + f"expected {len(shape_list)} dimensions, " + f"got {tensor.dim()}" + ) + + # Check each dimension + for shape_idx, expected_shape in shape_list: + actual_shape = tensor.shape[shape_idx] + if expected_shape != -1 and actual_shape != expected_shape: + raise ValueError( + f"Static shape mismatch for parameter {param}: " + f"expected {expected_shape} at index {shape_idx}, " + f"got {actual_shape}" + ) + + cpdef void _check_static_strides(self, list tensor_list): + for param, (buffer_idx, strides_list) in self.static_strides_map.items(): + tensor = tensor_list[buffer_idx] + if not isinstance(tensor, torch.Tensor): + # otherwise, maybe torch.data_ptr() for T.ptr inputs + continue + for stride_idx, expected_stride in strides_list: + # Ensure the stride index is within the valid range of tensor dimensions + # (stride_idx should be less than the number of dimensions of the tensor) + assert stride_idx < tensor.dim(), f"Stride index {stride_idx} out of bounds for tensor with {tensor.dim()} dimensions" + if tensor.shape[stride_idx] == 1: + continue + actual_stride = tensor.stride(stride_idx) + if actual_stride != expected_stride: + raise ValueError( + f"Static stride mismatch for parameter {param}: " + f"expected {expected_stride} at index {stride_idx}, " + 
f"got {actual_stride}" + ) + + cpdef void _check_static_contiguous(self, list tensor_list): + for buffer_idx, param in self.static_contiguous_list: + tensor = tensor_list[buffer_idx] + if not isinstance(tensor, torch.Tensor): + # otherwise, maybe torch.data_ptr() for T.ptr inputs + continue + if not tensor.is_contiguous(): + raise ValueError(f"Expected parameter {param} to be a contiguous tensor") + + cdef object _infer_output_device(self, list inputs): + for tensor in inputs: + if isinstance(tensor, torch.Tensor): + return tensor.device + return torch.cuda.current_device() + + cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False): + # Validate input dimensions and prepare for kernel execution + cdef int total_params = len(self.params) + cdef int total_inputs = len(inputs) + cdef int total_result_idx = len(self.result_idx) + cdef int total_dynamic_symbolics = len(self.dynamic_symbolic_map) + + # Ensure the number of inputs matches expected parameter count + if total_params != total_inputs + total_result_idx: + raise ValueError( + f"Expected {len(self.params)} inputs, got {len(inputs) + len(self.result_idx)} with {len(inputs)} inputs and {len(self.result_idx)} outputs" + ) + + # Use current CUDA stream if none specified + if stream == -1: + if torch.cuda.is_available(): + try: + stream = torch._C._cuda_getCurrentRawStream(torch.cuda.current_device()) + except ImportError: + stream = torch.cuda.current_stream().cuda_stream + else: + stream = 0 + + cdef int ins_idx = 0 + cdef list tensor_list = [] + device = None + + # Prepare input and output tensors + for i in range(len(self.params)): + if i in self.result_idx: + dtype = self.param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in self.param_shapes[i]: + if isinstance(s, tir.Var): + for key in self.dynamic_symbolic_map: + if(str(s) == str(key)): + ref_id, ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key] + shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) + else: # Already converted to Python int during initialization + shape.append(s) + + if device is None: + device = self._infer_output_device(inputs) + + if len(shape) == 0: + param_name = self.params[i].name if hasattr(self.params[i], 'name') else f'parameter_{i}' + raise ValueError( + f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. " + f"Expected shape: {shape}" + ) + tensor = torch.empty(*shape, dtype=dtype, device=device) + else: + tensor = inputs[ins_idx] + ins_idx += 1 + # TODO(chenggang): remove this check or rewrite by ourselves? 
+ ''' + if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous(): + base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride()) + if torch._debug_has_internal_overlap(base_tensor): + raise ValueError(f"Cannot use an overlapping tensor" + f"(shape={tensor.shape}, strides={tensor.stride()}, " + f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input") + ''' + tensor_list.append(tensor) + + # Convert tensor pointers to C void pointers for kernel call + cdef dict dtype_to_ctype = { + torch.float16: ctypes.c_float, + torch.float32: ctypes.c_float, + torch.float64: ctypes.c_double, + torch.int8: ctypes.c_int8, + torch.int16: ctypes.c_int16, + torch.int32: ctypes.c_int32, + torch.int64: ctypes.c_int64, + torch.bool: ctypes.c_bool, + } + + call_args = [] + for i, tensor in enumerate(tensor_list): + if isinstance(tensor, torch.Tensor): + call_args.append(ctypes.c_void_p(tensor.data_ptr())) + elif isinstance(tensor, (int, float, bool)): + if i in self.ptr_map: + call_args.append(ctypes.c_void_p(tensor)) + else: + dtype = self.param_dtypes[i] + if dtype not in dtype_to_ctype: + raise ValueError(f"Unsupported tensor dtype: {dtype}") + call_args.append(dtype_to_ctype[dtype](tensor)) + elif tensor is None: + call_args.append(ctypes.c_void_p(0)) + else: + raise ValueError(f"Unsupported tensor type: {type(tensor)}") + + # Check buffer device + if not skip_tensor_validation: + self._check_buffer_device(tensor_list) + self._check_buffer_dtype(tensor_list) + self._check_static_shape(tensor_list) + self._check_static_strides(tensor_list) + self._check_static_contiguous(tensor_list) + + # Add dynamic dimension values to kernel arguments + for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): + if ref_id == 0: + call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx])) + else: + call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx))) + + # Add CUDA stream to kernel arguments + call_args.append(ctypes.c_void_p(stream)) + + # Execute the kernel + result = self.lib.call(*call_args) + if result != 0: + error_msg = self.lib.get_last_error().decode('utf-8') + raise RuntimeError(f"Kernel call failed: {error_msg}") + + # Return output tensor(s) + if len(self.result_idx) == 1: + return tensor_list[self.result_idx[0]] + else: + return [tensor_list[i] for i in self.result_idx] + diff --git a/tilelang/original/tilelang/jit/adapter/libgen.py b/tilelang/original/tilelang/jit/adapter/libgen.py new file mode 100644 index 0000000000000000000000000000000000000000..d67f5b403ee010f3c8b52b5d3c5c802f89e88c3a --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/libgen.py @@ -0,0 +1,172 @@ +from __future__ import annotations +import ctypes +import logging +import os +import subprocess +import tempfile +from typing import Any + +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.transform import PassConfigKey +from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version +from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch +from tilelang.env import TILELANG_TEMPLATE_PATH +from tilelang.utils.deprecated import deprecated_warning + +from .utils import is_cpu_target, is_cuda_target, is_hip_target + +logger = logging.getLogger(__name__) + + +class LibraryGenerator: + srcpath: str | None = None + libpath: str | None = None + lib_code: str | None = None + pass_configs: dict[str, Any] | None = None + compile_flags: list[str] | 
None = None + + def __init__(self, target: Target, verbose: bool = False): + self.target = target + self.verbose = verbose + + def assign_pass_configs(self, pass_configs: dict[str, Any] | None = None): + self.pass_configs = pass_configs + + def assign_compile_flags(self, compile_flags: list[str] | None = None): + if compile_flags is None: + compile_flags = [] + self.compile_flags = compile_flags + + def update_lib_code(self, lib_code: str): + self.lib_code = lib_code + + # Assume currently we only support CUDA compilation + def load_lib(self, lib_path: str | None = None): + if lib_path is None: + lib_path = self.libpath + else: + self.libpath = lib_path + return ctypes.CDLL(lib_path) + + def compile_lib(self, timeout: float = None): + target = self.target + verbose = self.verbose + if is_cuda_target(target): + from tilelang.env import CUTLASS_INCLUDE_DIR + + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 + target_arch = get_target_arch(get_target_compute_version(target)) + libpath = src.name.replace(".cu", ".so") + + if self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH): + deprecated_warning( + "TL_DISABLE_FAST_MATH", + "TL_ENABLE_FAST_MATH", + "0.1.7", + ) + enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True) + else: + enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False) + + ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) + + command = [ + get_nvcc_compiler(), + "-std=c++17", + "-w", # Disable all warning messages + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "-fPIC", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + "-gencode", + f"arch=compute_{target_arch},code=sm_{target_arch}", + ] + if enable_fast_math: + command += ["--use_fast_math"] + if ptxas_usage_level is not None: + command += [f"--ptxas-options=--register-usage-level={ptxas_usage_level}"] + if verbose_ptxas_output: + command += ["--ptxas-options=--verbose"] + command += [ + "-I" + CUTLASS_INCLUDE_DIR, + ] + + elif is_hip_target(target): + from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR + + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 + libpath = src.name.replace(".cpp", ".so") + rocm_path = find_rocm_path() + arch = get_rocm_arch(rocm_path) + command = [ + "hipcc", + "-std=c++17", + "-fPIC", + f"--offload-arch={arch}", + "--shared", + src.name, + ] + command += [ + "-I" + COMPOSABLE_KERNEL_INCLUDE_DIR, + ] + elif is_cpu_target(target): + from tilelang.contrib.cc import get_cplus_compiler + + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 + libpath = src.name.replace(".cpp", ".so") + + command = [get_cplus_compiler(), "-std=c++17", "-fPIC", "-shared", src.name] + command += [ + "-I" + TILELANG_TEMPLATE_PATH, + ] + else: + raise ValueError(f"Unsupported target: {target}") + + command += [ + "-I" + TILELANG_TEMPLATE_PATH, + ] + + if self.compile_flags: + command += [item for flag in self.compile_flags for item in flag.split() if item not in command] + + command += ["-o", libpath] + + src.write(self.lib_code) + src.flush() + + try: + if verbose: + print(f"compile_lib compilation command: {' '.join(command)}") + ret = subprocess.run(command, timeout=timeout) + except Exception as e: + raise RuntimeError(f"Compile kernel failed because of {e}") from e + + if 
ret.returncode != 0: + raise RuntimeError(f"Compilation Failed! {command}\n {self.lib_code}") + + self.srcpath = src.name + self.libpath = libpath + + def remove_lib(self): + if self.libpath: + os.remove(self.libpath) + self.libpath = None + + def get_source_path(self): + return self.srcpath + + def get_lib_path(self): + return self.libpath + + def set_lib_path(self, libpath): + self.libpath = libpath + + def set_src_path(self, srcpath): + self.srcpath = srcpath diff --git a/tilelang/original/tilelang/jit/adapter/nvrtc/__init__.py b/tilelang/original/tilelang/jit/adapter/nvrtc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c8abe8d7789faff0ebcebb2e306aaaba436c4a5f --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/nvrtc/__init__.py @@ -0,0 +1,68 @@ +"""NVRTC Backend for TileLang. + +This module provides runtime compilation support using NVIDIA's NVRTC API. +""" + +import logging + +__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"] + +logger = logging.getLogger(__name__) + +# Check if cuda-python is available +is_nvrtc_available = False +NVRTC_UNAVAILABLE_MESSAGE = ( + "cuda-python is not available, NVRTC backend cannot be used. " + "Please install cuda-python via `pip install cuda-python` " + "if you want to use the NVRTC backend." +) + +try: + import cuda.bindings.driver as cuda # noqa: F401 + import cuda.bindings.nvrtc as nvrtc # noqa: F401 + + is_nvrtc_available = True +except ImportError as e: + logger.debug(f"cuda-python import failed: {e}") + + +def check_nvrtc_available(): + """Check if NVRTC backend is available. + + Raises + ------ + ImportError + If cuda-python is not installed or cannot be imported + """ + if not is_nvrtc_available: + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + +# Conditionally import the adapter +if is_nvrtc_available: + from .adapter import NVRTCKernelAdapter + from .wrapper import TLNVRTCSourceWrapper + from .libgen import NVRTCLibraryGenerator +else: + # Provide a dummy class that raises error on instantiation + class NVRTCKernelAdapter: + """Dummy NVRTCKernelAdapter that raises ImportError on instantiation.""" + + def __init__(self, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + @classmethod + def from_database(cls, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + class TLNVRTCSourceWrapper: + """Dummy TLNVRTCSourceWrapper that raises ImportError on instantiation.""" + + def __init__(self, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + class NVRTCLibraryGenerator: + """Dummy NVRTCLibraryGenerator that raises ImportError on instantiation.""" + + def __init__(self, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) diff --git a/tilelang/original/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/original/tilelang/jit/adapter/nvrtc/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..083c8f215d1cd0f3bd7a36c9b20a6963fe09a531 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/nvrtc/adapter.py @@ -0,0 +1,269 @@ +from __future__ import annotations +import logging +from typing import Any, Callable + +import torch +from tvm import tir +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.engine.param import KernelParam +from tilelang.jit.adapter.wrapper import TLPyWrapper +from tilelang.utils.language import retrieve_func_from_module +from tilelang.utils.target import determine_target +from 
tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available + +from .libgen import NVRTCLibraryGenerator + +logger = logging.getLogger(__name__) + +# Import cuda bindings if available +if is_nvrtc_available: + import cuda.bindings.driver as cuda + + +class NVRTCKernelAdapter(BaseKernelAdapter): + pymodule = None + kernels = {} + + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + check_nvrtc_available() + + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + self.ir_module = func_or_mod + + # Cache parameter information during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + self.param_dtypes = [param.torch_dtype() for param in params] + self.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + self.param_shapes.append(native_shape) + + self.dynamic_symbolic_map = self._process_dynamic_symbolic() + + self.target = Target.canon_target(determine_target(target)) + self.verbose = verbose + self.wrapper = TLPyWrapper(self.target) + self.wrapper.assign_optimized_module(self.ir_module) + self.wrapper.assign_pass_configs(pass_configs) + self.wrapper.assign_host_module(host_mod) + self.wrapper.assign_device_module(device_mod) + wrapper_result = self.wrapper.wrap(device_kernel_source) + self.host_func = wrapper_result["host_func"] + self.function_names = wrapper_result["function_names"] + + self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) + self.lib_generator.update_lib_code(self.device_kernel_source) + self.lib_generator.update_host_func(self.host_func) + self.lib_generator.assign_compile_flags(compile_flags) + self.lib_generator.compile_lib() + self.lib_generator.load_lib() + self.libpath = self.lib_generator.libpath + self.pymodule = self.lib_generator.pymodule + culib = self.lib_generator.culib + for name in self.function_names: + result, self.kernels[name] = cuda.cuLibraryGetKernel(culib, bytes(name, "utf-8")) + assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to get kernel: {name}" + + self._post_init() + + @classmethod + def from_database( + cls, + params: list[KernelParam], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: 
func_or_mod}) + else: + adapter.ir_module = func_or_mod + + # Cache parameter information during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + adapter.param_dtypes = [param.torch_dtype() for param in params] + adapter.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + adapter.param_shapes.append(native_shape) + + adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic() + + adapter.target = Target.canon_target(determine_target(target)) + adapter.verbose = verbose + adapter.lib_generator = NVRTCLibraryGenerator(adapter.target, adapter.verbose) + adapter.lib_generator.assign_compile_flags(compile_flags) + adapter.lib_generator.load_lib(lib_path=kernel_lib_path) + adapter.pymodule = adapter.lib_generator.pymodule + adapter.function_names = adapter.pymodule._function_names + + culib = adapter.lib_generator.culib + for name in adapter.function_names: + result, adapter.kernels[name] = cuda.cuLibraryGetKernel(culib, bytes(name, "utf-8")) + assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to get kernel: {name}" + + adapter._post_init() + return adapter + + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: + """Extract information about dynamic shapes from the TIR function. + + Maps symbolic variables to their corresponding (buffer_index, shape_dimension) + for runtime shape resolution. + + Returns + ------- + Dict[tir.Var, Tuple[int, int]] + Mapping from symbolic variable to (buffer_index, shape_dimension) + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map = {} + for i, param in enumerate(params): + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map): + dynamic_symbolic_map[shape] = (i, j) + return dynamic_symbolic_map + + def get_kernel_source(self, kernel_only: bool = True) -> str | None: + """Get the CUDA kernel source code. + + Returns + ------- + Optional[str] + The kernel source code, or None if not available + """ + if kernel_only: + return self.device_kernel_source + else: + return self.host_func + + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): + """Low-level function to call the compiled CUDA kernel.""" + return self.pymodule.call(self.kernels, *args, stream=stream) + + def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): + """High-level wrapper for kernel execution. + + Handles: + 1. Input validation + 2. Output tensor allocation + 3. Dynamic shape resolution + 4. 
CUDA stream management + + Args: + ins: Input PyTorch tensors + stream: Optional CUDA stream for asynchronous execution + + Returns: + Single tensor or list of tensors containing the kernel results + """ + if len(ins) + len(self.result_idx) != len(self.params): + raise ValueError( + f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" + ) + ins_idx = 0 + args = [] + + # tensor pointers + for i in range(len(self.params)): + if i in self.result_idx: + dtype = self.param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in self.param_shapes[i]: + if isinstance(s, tir.Var): + ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s] + shape.append(ins[ref_tensor_idx].shape[ref_shape_idx]) + else: # Already converted to Python int during initialization + shape.append(s) + device = ins[0].device if len(ins) > 0 else torch.cuda.current_device() + tensor = torch.empty(*shape, dtype=dtype, device=device) + else: + tensor = ins[ins_idx] + ins_idx += 1 + args.append(tensor) + + # dynamic symbolics + for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): + args.append(ins[buffer_idx].shape[shape_idx]) + + # if stream is not None, we need to pass the stream to the library + if stream is None: + if str(self.target).startswith("cuda") and torch.cuda.is_available(): + stream = torch.cuda.current_stream().cuda_stream + else: + stream = 0 + + self._forward_from_prebuild_lib(*args, stream=stream) + + if len(self.result_idx) == 1: + return args[self.result_idx[0]] + else: + return [args[i] for i in self.result_idx] + + def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]: + """Convert to a PyTorch-compatible function. + + Returns + ------- + Callable[..., Union[torch.Tensor, List[torch.Tensor]]] + A callable function that takes tensors and returns tensor(s) + """ + return self._wrap_forward_from_prebuild_lib + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) diff --git a/tilelang/original/tilelang/jit/adapter/nvrtc/libgen.py b/tilelang/original/tilelang/jit/adapter/nvrtc/libgen.py new file mode 100644 index 0000000000000000000000000000000000000000..406cc44d97f8eee513c8c427961abf908289a665 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/nvrtc/libgen.py @@ -0,0 +1,233 @@ +"""NVRTC Library Generator for TileLang. + +Compiles CUDA kernels at runtime using NVRTC and manages resulting binaries. 
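+
+ Illustrative lifecycle (a sketch; variable names are hypothetical):
+
+     gen = NVRTCLibraryGenerator(target)
+     gen.update_lib_code(cuda_source)        # CUDA C++ kernel source
+     gen.update_host_func(python_launcher)   # generated Python call() code
+     gen.compile_lib()                       # writes .cu / .cubin / .py to a temp dir
+     gen.load_lib()                          # imports the launcher, loads the cubin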
+ +Why NVRTC instead of nvcc: +- No offline compilation step, enables true JIT workflows +- Works without CUDA toolkit installed (only requires driver) +- Allows kernel specialization based on runtime parameters + +Key responsibilities: +- Compile CUDA source to cubin using NVRTC API +- Generate accompanying Python launcher code +- Load compiled cubin and extract kernel handles +- Manage library lifecycle (load/unload) +""" + +from __future__ import annotations +import importlib +import logging +import os.path as osp +import platform +import tempfile +from types import ModuleType + +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.jit.adapter.utils import is_cuda_target +from tilelang.jit.adapter.nvrtc import is_nvrtc_available, NVRTC_UNAVAILABLE_MESSAGE + +logger = logging.getLogger(__name__) + +if is_nvrtc_available: + import cuda.bindings.driver as cuda + from tilelang.contrib.nvrtc import compile_cuda +else: + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + +class NVRTCLibraryGenerator(LibraryGenerator): + """Runtime compiler and loader for NVRTC-compiled CUDA kernels. + + Lifecycle: + 1. compile_lib(): CUDA source → cubin + Python launcher + 2. load_lib(): cubin → loaded library + kernel handles + 3. pymodule.call(): Execute kernels via Python launcher + 4. __del__: Cleanup (unload library) + + Why three files (cu, cubin, py): + - .cu: Source for debugging, kept in temp directory + - .cubin: Compiled binary, loaded by CUDA driver + - .py: Launch code, imported as Python module + + Attributes: + host_func: Generated Python launch code (from wrapper) + culib: CUDA library handle (CUlibrary) + pymodule: Imported Python module containing call() function + """ + + host_func: str | None = None + culib: cuda.CUlibrary | None = None + pymodule: ModuleType | None = None + pypath: str | None = None + + def __init__(self, target: Target, verbose: bool = False): + """Initialize NVRTC library generator. + + Args: + target: Compilation target (must be CUDA) + verbose: Enable verbose compilation output + """ + super().__init__(target, verbose) + + @staticmethod + def import_from_file(module_name, file_path): + """Dynamically import Python module from file path. + + Standard importlib pattern for loading modules outside sys.path. + Used to import generated .py launcher code from temp directory. + + Args: + module_name: Name to assign to imported module + file_path: Absolute path to .py file + + Returns: + Imported module object + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Failed to import module from file: {file_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def update_host_func(self, host_func: str): + """Store generated Python launch code for later file write. + + Called by adapter after wrapper generates the launch code. + This is the bridge between code generation and file output. + + Args: + host_func: Python source code containing call() function + """ + self.host_func = host_func + + def load_lib(self, lib_path: str | None = None): + """Load compiled cubin and Python launcher into memory. + + Why two loads: + 1. Import Python module for launch logic + 2. Load cubin via CUDA Driver API for kernel handles + + Context synchronization: CUDA context must be current before loading. + If not, use torch.cuda.synchronize() to establish context. 
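+
+         Illustrative use (a sketch; the cubin path is hypothetical):
+
+             import torch
+             torch.zeros(1, device="cuda")          # any CUDA work establishes a context
+             gen.load_lib("/tmp/kernel_xyz.cubin")  # sets self.pymodule and self.culib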
+ + Args: + lib_path: Path to .cubin file (optional, uses self.libpath if None) + + Side effects: + - Sets self.pymodule to imported Python module + - Sets self.culib to CUDA library handle + """ + if lib_path is None: + lib_path = self.libpath + else: + self.libpath = lib_path + + self.pypath = lib_path.replace(".cubin", ".py") + self.pymodule = self.import_from_file("kernel", self.pypath) + + # Ensure the context is valid + ctx = cuda.cuCtxGetCurrent()[1] + if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS: + import torch + + torch.cuda.synchronize() + + result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) + if result != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}") + + def compile_lib(self, timeout: float | None = None): + """Compile CUDA source to cubin using NVRTC and write output files. + + Output artifacts (all in temp directory): + - .cu: Source code (for debugging) + - .cubin: Compiled binary (for execution) + - .py: Python launcher (for calling kernels) + + Include paths setup: + - TileLang templates: kernel primitives and utilities + - CUTLASS: optimized GEMM/tensor ops + - CUDA headers: driver/runtime APIs + + Why architecture detection: + ARM64 servers (SBSA) have different header paths than x86_64. + + Args: + timeout: Compilation timeout in seconds (currently unsupported by NVRTC compiler) + + Side effects: + - Writes .cu, .cubin, .py files to temp directory + - Sets self.srcpath, self.libpath, self.pypath + """ + target = self.target + verbose = self.verbose + if is_cuda_target(target): + from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH + + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + libpath = src.name.replace(".cu", ".cubin") + + project_root = osp.join(osp.dirname(__file__), "..", "..") + if CUTLASS_INCLUDE_DIR is None: + cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) + else: + cutlass_path = CUTLASS_INCLUDE_DIR + + if TILELANG_TEMPLATE_PATH is None: + tl_template_path = osp.abspath(osp.join(project_root, "src")) + else: + tl_template_path = TILELANG_TEMPLATE_PATH + + cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda" + __CUDACC_VER_MAJOR__ = cuda.CUDA_VERSION // 1000 + + # Determine target architecture + machine = platform.machine() + target_arch = "sbsa-linux" if machine in ("aarch64", "arm64") else "x86_64-linux" + + options = [ + f"-I{tl_template_path}", + f"-I{cutlass_path}", + f"-I{cuda_home}/include", + f"-I{cuda_home}/targets/{target_arch}/include", + f"-I{cuda_home}/targets/{target_arch}/include/cccl", + f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}", + ] + if self.compile_flags: + options += [item for flag in self.compile_flags for item in flag.split() if item not in options] + + cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose) + with open(libpath, "wb") as f: + f.write(cubin_bytes) + + src.write(self.lib_code) + src.flush() + + self.srcpath = src.name + self.libpath = libpath + self.pypath = src.name.replace(".cu", ".py") + if self.host_func is None: + raise RuntimeError("Host function is not set, please call update_host_func() first.") + with open(self.pypath, "w") as f: + f.write(self.host_func) + else: + raise ValueError(f"Unsupported target: {target}") + + def __del__(self): + """Cleanup: unload CUDA library when object is destroyed. 
+ + Critical for resource management - CUDA libraries consume GPU memory. + Failure to unload is logged but not raised (destructor can't fail). + + Why explicit unload: + Python GC doesn't know about GPU resources, must release manually. + """ + if self.culib: + result = cuda.cuLibraryUnload(self.culib)[0] + if result != cuda.CUresult.CUDA_SUCCESS: + logger.warning(f"Failed to unload library: {self.libpath}") + self.culib = None diff --git a/tilelang/original/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/original/tilelang/jit/adapter/nvrtc/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2316823ec61760133d7fcd01dff8dd670e805aa4 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/nvrtc/wrapper.py @@ -0,0 +1,581 @@ +"""NVRTC Source Wrapper for TileLang. + +Generates Python runtime code for launching CUDA kernels compiled via NVRTC. + +Why this exists: +- NVRTC compiles kernels at runtime, needs Python launch code (not C++) +- TMA descriptors must be initialized once per unique buffer, not per kernel +- L2 cache policies require explicit CUDA Driver API setup/teardown + +Key design: +- Two-pass generation: collect all descriptors first, then generate launches +- Dict-based deduplication ensures TMA descriptors created only once +- Generates pure Python using cuda.bindings.driver for zero C++ dependency +""" + +from __future__ import annotations +from typing import Any, ClassVar + +from tvm import IRModule +from tvm.target import Target +from tvm.tir.stmt_functor import post_order_visit + +from tilelang import tvm as tvm +from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper +from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args + +PREDEF_HOST_FUNC_PY = """ +from cuda.bindings.driver import ( + CUtensorMapDataType, + CUtensorMapInterleave, + CUtensorMapSwizzle, + CUtensorMapL2promotion, + CUtensorMapFloatOOBfill, + cuTensorMapEncodeTiled, + cuTensorMapEncodeIm2col, + CUresult, + cuKernelSetAttribute, + CUfunction_attribute, + CUdevice, + CUlaunchConfig, + cuLaunchKernelEx, + cuuint64_t, + cuuint32_t, + CUkernel, +) +import ctypes + +_function_names = {} + +def call({}): + {} +""" + +TMA_DESC_INIT_FUNC_PY = """ + {0}_type = CUtensorMapDataType({1}) + {0}_tensorRank = {2} + {0}_globalAddress = {3}.data_ptr() + {0}_globalDim = [{4}] + {0}_globalStride = [{5}][1:] + {0}_boxDim = [{6}] + {0}_elementStrides = [{7}] + {0}_interleave = CUtensorMapInterleave({8}) + {0}_swizzle = CUtensorMapSwizzle({9}) + {0}_l2Promotion = CUtensorMapL2promotion({10}) + {0}_oobFill = CUtensorMapFloatOOBfill({11}) + + res, {0} = cuTensorMapEncodeTiled( + {0}_type, + {0}_tensorRank, + {0}_globalAddress, + {0}_globalDim, + {0}_globalStride, + {0}_boxDim, + {0}_elementStrides, + {0}_interleave, + {0}_swizzle, + {0}_l2Promotion, + {0}_oobFill, + ) + + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}") +""" + +TMA_IM2COL_DESC_INIT_FUNC_PY = """ + {0}_type = CUtensorMapDataType({1}) + {0}_tensorRank = {2} + {0}_globalAddress = {3}.data_ptr() + {0}_globalDim = [{4}] + {0}_globalStride = [{5}][1:] + {0}_elementStrides = [{6}] + {0}_lowerCorner = [{7}] + {0}_upperCorner = [{8}] + {0}_channelsPerPixel = {9} + {0}_pixelsPerColumn = {10} + {0}_interleave = CUtensorMapInterleave({11}) + {0}_swizzle = CUtensorMapSwizzle({12}) + {0}_l2Promotion = CUtensorMapL2promotion({13}) + {0}_oobFill = CUtensorMapFloatOOBfill({14}) + + res, {0} = 
cuTensorMapEncodeIm2col( + {0}_type, + {0}_tensorRank, + {0}_globalAddress, + {0}_globalDim, + {0}_globalStride, + {0}_lowerCorner, + {0}_upperCorner, + {0}_channelsPerPixel, + {0}_pixelsPerColumn, + {0}_elementStrides, + {0}_interleave, + {0}_swizzle, + {0}_l2Promotion, + {0}_oobFill, + ) + + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}") +""" + +L2_PERSISTENT_MAP_CREATE_HANDLE_PY = """ + from cuda.bindings.driver import ( + CUstreamAttrValue, + CUstreamAttrID, + CUlimit, + CUaccessProperty, + cuCtxGetLimit, + cuCtxSetLimit, + cuStreamSetAttribute, + cuCtxResetPersistingL2Cache, + ) + + stream_attribute = CUstreamAttrValue() + res, init_persisting_l2_cache_size = cuCtxGetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE) + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to get L2 cache size limit: {{res}}") +""" + +L2_PERSISTENT_MAP_INIT_FUNC_PY = """ + stream_attribute.accessPolicyWindow.hitRatio = {1} + stream_attribute.accessPolicyWindow.hitProp = CUaccessProperty.CU_ACCESS_PROPERTY_PERSISTING + stream_attribute.accessPolicyWindow.missProp = CUaccessProperty.CU_ACCESS_PROPERTY_STREAMING + + res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, {2})[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set L2 cache size limit: {{res}}") + + stream_attribute.accessPolicyWindow.base_ptr = {0}.data_ptr() + stream_attribute.accessPolicyWindow.num_bytes = {2} + + res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set stream L2 access policy: {{res}}") +""" + +L2_PERSISTENT_MAP_RESET_HANDLE_PY = """ + stream_attribute.accessPolicyWindow.num_bytes = 0 + res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to reset stream L2 access policy: {{res}}") + + res = cuCtxResetPersistingL2Cache()[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to reset L2 cache: {{res}}") + + res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, init_persisting_l2_cache_size)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to restore L2 cache size limit: {{res}}") +""" + +KERNEL_LAUNCH_FUNC_PY = """ + res = cuKernelSetAttribute( + CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + {7}, + kernels["{0}"], + CUdevice({10}) + )[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}") + + config = CUlaunchConfig() + config.gridDimX = {1} + config.gridDimY = {2} + config.gridDimZ = {3} + config.blockDimX = {4} + config.blockDimY = {5} + config.blockDimZ = {6} + config.sharedMemBytes = {7} + config.hStream = stream + + arg_values = {8} + arg_types = {9} + + res = cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to launch kernel {0}: {{res}}") +""" + + +class TLNVRTCSourceWrapper(TLCUDASourceWrapper): + """NVRTC backend wrapper: generates Python kernel launch code. + + Core responsibility: transform TVM IRModule into executable Python function + that initializes resources (TMA descriptors, L2 cache) and launches kernels + via CUDA Driver API. 
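+
+     Illustrative shape of the generated launcher (heavily simplified; the
+     argument names are hypothetical):
+
+         def call(kernels, A, B, C, m, stream):
+             # cuTensorMapEncodeTiled(...) once per unique TMA descriptor
+             # per kernel: cuKernelSetAttribute(...), CUlaunchConfig(),
+             #             cuLaunchKernelEx(config, kernels[name], ...)
+             ...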
+ + Data flow: + IRModule → collect kernel metadata → deduplicate resources → + generate Python code → executable function + + Why Python generation instead of C++: + NVRTC workflow requires runtime compilation, Python is the natural host. + Using cuda.bindings.driver eliminates C++ wrapper complexity. + """ + + _TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "ctypes.c_float", + "float16": "ctypes.c_uint16", + "bfloat16": "ctypes.c_uint16", + "float8_e4m3": "ctypes.c_uint8", + "float8_e4m3fn": "ctypes.c_uint8", + "float8_e5m2": "ctypes.c_uint8", + "float64": "ctypes.c_double", + "int64": "ctypes.c_int64", + "int32": "ctypes.c_int32", + "uint32": "ctypes.c_uint32", + "bool": "ctypes.c_bool", + "int8": "ctypes.c_int8", + "uint8": "ctypes.c_uint8", + "int16": "ctypes.c_int16", + "uint16": "ctypes.c_uint16", + "uchar": "ctypes.c_uint8", + } + + _generated_host_func: str | None = None + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + """Initialize NVRTC wrapper with compiled IR modules. + + Args: + scheduled_ir_module: TVM IR after scheduling passes + source: Generated CUDA C++ source code + target: Compilation target (should be NVRTC-compatible) + device_mod: Device-side IR module (kernel functions) + host_mod: Host-side IR module (launch logic) + pass_configs: Optional compiler pass configurations + """ + super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) + + @property + def host_func(self): + """Override parent's host_func to return generated Python code.""" + if self._generated_host_func is not None: + return self._generated_host_func + return super().host_func + + @host_func.setter + def host_func(self, value): + """Allow setting generated host function code.""" + self._generated_host_func = value + + def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to Python string, ignoring casts. + + Casts are noise in generated Python code - Python is dynamically typed. + """ + return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True, floor_div_op="//") + + def create_dispatch_func(self, code, function_informations): + """Generate Python dispatch function that launches multiple CUDA kernels. + + Why two-pass design: + Pass 1: Collect TMA descriptors from all kernels into shared dicts + Pass 2: Generate code - descriptors first (deduplicated), then launches + + Single-pass would create duplicate descriptors for each kernel. + Dict naturally deduplicates by descriptor name. + + Args: + code: CUDA C++ source containing kernel declarations + function_informations: Dict mapping kernel names to metadata + (grid/block dims, params, shared memory size) + + Returns: + Python source code defining a call() function that: + 1. Initializes L2 cache policies (if needed) + 2. Creates TMA descriptors once per unique buffer + 3. Launches each kernel with cuLaunchKernelEx + 4. 
Resets L2 cache policies (if needed) + """ + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + function_args = [{"name": "kernels", "type": "dict[str, CUkernel]"}] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append( + { + "name": buffer.data.name, + "type": "ctypes.c_void_p", + } + ) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) + else: + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") + # Add dynamic symbols as integer arguments + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: + if dyn_sym not in [arg["name"] for arg in function_args]: + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) + + function_args.append(self.get_stream_type()) + + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['name']}" for arg in function_args]) + + # Check if any function needs L2 Persistent Map + has_l2_persistent_map = False + for function_name, _ in function_informations.items(): + if function_name in self.l2_persistent_map: + has_l2_persistent_map = True + break + + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} + device_index = 0 + kernel_launch_code = """""" + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE_PY + + # First pass: collect all TMA descriptors from all kernels to avoid duplication + kernel_info_list = [] + for function_name, function_info in function_informations.items(): + block_info = function_info["block_info"] + grid_info = function_info["grid_info"] + dynamic_smem_buf = function_info["dynamic_smem_buf"] + function_params = function_info["function_params"] + + # Find the location of the global kernel function in the code + index = match_declare_kernel(code, function_name + "(") + + # Analyze the function declaration to prepare for argument extraction + declaration = code[index:].split(";")[0] + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + + # Transform function for NVRTC: returns (arg_value, arg_type) tuples + def transform_nvrtc_arg(name: str, arg_type: str): + if arg_type == "ctypes.c_void_p": + return (f"{name}.data_ptr()", arg_type) + return (name, arg_type) + + call_args = parse_function_call_args( + declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg + ) + + for arg_name, arg_type in call_args: + if arg_type == "ctypes.c_void_p": + device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index" + break + + # Store kernel info for second pass + kernel_info_list.append( + { + "function_name": function_name, + "block_info": block_info, + "grid_info": grid_info, + "dynamic_smem_buf": dynamic_smem_buf, + "call_args": call_args, + "device_index": device_index, + } + ) + + # Generate TMA descriptor initialization code once for all kernels + kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) + + # Second pass: generate kernel launch code for each kernel + for kernel_info in kernel_info_list: + function_name = kernel_info["function_name"] + block_info = kernel_info["block_info"] + grid_info = kernel_info["grid_info"] + 
dynamic_smem_buf = kernel_info["dynamic_smem_buf"] + call_args = kernel_info["call_args"] + device_index = kernel_info["device_index"] + + arg_names = ", ".join([arg[0] for arg in call_args]) + arg_types = ", ".join([arg[1] for arg in call_args]) + smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf + + # Generate L2 persistent map initialization for this function + init_l2_persistent_map = self.generate_l2_persistent_map(function_name) + kernel_launch_code += init_l2_persistent_map + + # Generate kernel launch code + kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format( + function_name, + self._pythonic_expr(grid_info[0]), + self._pythonic_expr(grid_info[1]), + self._pythonic_expr(grid_info[2]), + self._pythonic_expr(block_info[0]), + self._pythonic_expr(block_info[1]), + self._pythonic_expr(block_info[2]), + smem_str, + arg_names, + arg_types, + device_index, + ) + + # Reset L2 persistent map after all kernel execution + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY + + # Wrap the kernel dispatch logic in an external C function + host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code) + return host_func + + def generate_l2_persistent_map(self, function_name: str) -> str: + """Generate Python code to configure L2 cache persistence for a kernel. + + L2 persistence pins frequently-accessed data in L2 cache to reduce + memory bandwidth. Requires explicit setup via CUDA stream attributes. + + Args: + function_name: Kernel name to check for L2 persistence config + + Returns: + Python code that sets stream access policy window, or empty + string if no L2 persistence configured for this kernel. + """ + if function_name not in self.l2_persistent_map: + return "" + init_l2_persistent_map = "" + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): + # Get persisting_l2_cache_max_size + from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() + try: + num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) + except TypeError: + # as size_in_bytes may be a symbolic expression + num_bytes = persisting_l2_cache_max_size + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + + return init_l2_persistent_map + + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + """Generate Python code to initialize TMA descriptors. + + TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects + that describe memory layout for async copies. Must be created on host + before kernel launch. + + Args: + desc_name_map: Maps descriptor variable names to buffer names + desc_name_var_map: Maps descriptor names to TVM variables + + Returns: + Python code that calls cuTensorMapEncodeTiled/Im2col for each + unique descriptor. Empty string if no TMA descriptors needed. 
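The `try/except TypeError` in `generate_l2_persistent_map` above is easiest to see with concrete numbers; the byte counts and the device limit below are invented.

```python
# Invented numbers: the device reports a 2 MiB persisting-L2 limit.
persisting_l2_cache_max_size = 2 * 1024 * 1024

try:
    # A concrete 8 MiB request is clamped down to the device limit.
    num_bytes = min(8 * 1024 * 1024, persisting_l2_cache_max_size)
except TypeError:
    # A symbolic size (a tir expression) cannot be compared with an int,
    # so the wrapper falls back to the maximum persisting size instead.
    num_bytes = persisting_l2_cache_max_size

assert num_bytes == persisting_l2_cache_max_size
```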
+ """ + tma_descriptor_init = "" + if self.tma_descriptor_args is None: + return tma_descriptor_init + + # Parse TMA descriptor arguments using the common utility + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) + + # Generate Python code from parsed parameters + for params in parsed_params: + if not params.is_img2col: + tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format( + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), + ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)), + ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) + else: + tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format( + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), + ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), + ", ".join(params.lower_corner), + ", ".join(params.upper_corner), + params.smem_box_channel, + params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) + + return tma_descriptor_init + + def update_lib_code(self, code: str): + """Update library code and generate host dispatch function. + + Entry point for code generation. Walks the host IR to extract kernel + call sites, matches them with device kernels, then generates Python + dispatch code via create_dispatch_func(). + + Args: + code: CUDA C++ source code containing compiled kernels + + Returns: + The same code string (stored in self.lib_code). Side effect: + sets self.host_func to generated Python dispatcher. 
+ """ + # Update the library code with the given code string + self.lib_code = code + + # Organize function information for code generation + function_informations = {} + for function_name in self.function_names: + # Do not update function with dispatch host function + if (function_name not in self.block_info) or (function_name not in self.grid_info): + continue + + assert function_name in self.device_mod, f"Function {function_name} not found in device module" + device_func = self.device_mod[function_name] + kernel_params_cnt = len(device_func.params) + function_params: list[str] | None = None + + def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): + nonlocal function_params + if isinstance(node, tvm.tir.Call): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + return + args = node.args + if not args or args[0] != fn: + return + if len(args) < 1 + param_cnt: + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] + + post_order_visit(self.host_func.body, visitor) + assert function_params is not None, "function_params should not be None" + + function_informations[function_name] = { + "function_name": function_name, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + "function_params": function_params, + } + + # Create the host function wrapper for the CUDA kernel + self.host_func = self.create_dispatch_func(code, function_informations) + return self.lib_code + + def get_stream_type(self) -> dict[str, str]: + """Return stream parameter spec for Python signature. + + NVRTC backend uses raw int for stream handle (not cudaStream_t pointer). + Default to 0 (NULL stream) for convenience. 
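For clarity, the packed call that the visitor in `update_lib_code` matches carries the kernel name first and the device kernel's parameters immediately after it. A small sketch of the extraction, with a hypothetical kernel name and parameter count:

```python
# Hypothetical: the host IR contains T.call_packed("main_kernel", A, B, n, ...)
# and the device function main_kernel takes 3 parameters.
def packed_kernel_params(call, kernel_name, param_cnt):
    """Mirror of the visitor's extraction: args[0] is the kernel name and the
    next param_cnt entries are the kernel parameters in declaration order."""
    assert call.args and call.args[0] == kernel_name
    assert len(call.args) >= 1 + param_cnt
    return call.args[1 : 1 + param_cnt]
```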
+ """ + return {"name": "stream=0", "type": "int"} diff --git a/tilelang/original/tilelang/jit/adapter/torch/__init__.py b/tilelang/original/tilelang/jit/adapter/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f688993d0e563f1b14a7c4ecbe91ce4a9bd344b3 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/torch/__init__.py @@ -0,0 +1,3 @@ +from .metal import MetalKernelAdapter + +__all__ = ["MetalKernelAdapter"] diff --git a/tilelang/original/tilelang/jit/adapter/torch/metal.py b/tilelang/original/tilelang/jit/adapter/torch/metal.py new file mode 100644 index 0000000000000000000000000000000000000000..4690cf59bda7cd1907c130d5a8d1446466e097b5 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/torch/metal.py @@ -0,0 +1,71 @@ +from __future__ import annotations +from functools import wraps +from typing import Callable + +import torch +from tvm import tir + +from tilelang import tvm as tvm + +from ..base import BaseKernelAdapter +from tilelang.engine.param import KernelParam + + +class MetalKernelAdapter(BaseKernelAdapter): + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + # target: Union[str, Target], + func_or_mod: tir.PrimFunc | tvm.IRModule, + # host_mod: Optional[tvm.IRModule] = None, + device_mod: tvm.IRModule | None = None, + kernel_global_source: str | None = None, + verbose: bool = False, + # pass_configs: Optional[Dict[str, Any]] = None, + # compile_flags: Optional[List[str]] = None + ): + self.kernel_global_source = kernel_global_source + if isinstance(func_or_mod, tir.PrimFunc): + func_name = func_or_mod.attrs["global_symbol"] + else: + func_name = func_or_mod.__name__ + self.kernel_name = func_name + "_kernel" + self.verbose = verbose + + self.block_info = [1, 1, 1] + self.grid_info = [1, 1, 1] + + for var, func in device_mod.functions.items(): + assert var.name_hint == self.kernel_name + thread_extent = func.attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + break + else: + raise AssertionError(f"no kernel with name {func_name}") + + # print(self.block_info, self.grid_info) + super().__init__(func_or_mod, result_idx=result_idx, params=params) + + _kernel = None + + def _convert_torch_func(self) -> Callable: + if self._kernel is None: + _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) + _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)] + + @wraps(_kernel) + def launcher(*args: torch.Tensor): + return _kernel( + *args, + threads=_threads, + group_size=self.block_info, + ) + + self._kernel = launcher + + return self._kernel diff --git a/tilelang/original/tilelang/jit/adapter/tvm_ffi.py b/tilelang/original/tilelang/jit/adapter/tvm_ffi.py new file mode 100644 index 0000000000000000000000000000000000000000..fdba92c210858eafb31ecf25b96956da8e8f501b --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/tvm_ffi.py @@ -0,0 +1,314 @@ +"""Utilities to adapt TVM FFI kernels to Torch tensors. + +This adapter intentionally captures PyTorch's current CUDA stream and device +via light-weight callables so that, when the wrapped function is invoked, +the execution observes the same stream context as the active Torch code. +On non-CUDA builds, the stream/device fall back to 0/CPU semantics. 
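Back in `MetalKernelAdapter._convert_torch_func` above, the compiled MPS shader is dispatched with the *total* thread count per axis plus the threadgroup size; a self-contained sketch with invented extents:

```python
# Invented extents: a 4x2x1 grid of 128x1x1 threadgroups.
block_info = [128, 1, 1]   # threadIdx extents
grid_info = [4, 2, 1]      # blockIdx extents
threads = [b * g for b, g in zip(block_info, grid_info)]
# threads == [512, 2, 1]; these are passed as threads=..., group_size=block_info
# to the kernel object returned by torch.mps.compile_shader.
```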
+""" + +from __future__ import annotations + +from typing import Callable, Any + +import torch +from tilelang import tvm +from tvm import runtime, tir +from tvm.target import Target +from tvm.relax import TensorType +from tilelang.utils.target import determine_target +from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.utils.language import retrieve_func_from_module +from tilelang.engine.param import KernelParam + + +class TVMFFIKernelAdapter(BaseKernelAdapter): + """Adapter that runs a TVM runtime.Executable with Torch tensors. + + Notes + - We capture the "current" PyTorch CUDA stream/device as thunks (callables) + rather than materializing them at construction time. This ensures the + actual stream/device is read just-in-time when the function runs, matching + the user's current Torch context (e.g., after a stream guard/switch). + - The stream pointer returned is a raw CUDA stream handle compatible with + TVM's device API; on CPU or when CUDA is unavailable, we return 0. + """ + + # Class attributes to store compiled kernel information + target: str | Target = "cuda" + ir_module: tvm.IRModule | None = None + # The global source code of the kernel -> global means the source code of the kernel + # that is not wrapped by the wrapper code + host_kernel_source: str | None = None + device_kernel_source: str | None = None + executable: tvm.runtime.Executable | None = None + # Pass configs for the compiler + pass_configs: dict[str, Any] | None = None + # host_mod + host_mod: tvm.IRModule | None = None + # device_mod + device_mod: tvm.IRModule | None = None + # rt_mod + rt_mod: tvm.runtime.Module | None = None + # Maps symbolic variables to their corresponding buffer and shape indices + dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None + + # Stream/device functors are inherited from BaseKernelAdapter + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + rt_mod: tvm.runtime.Module | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + """Initialize the adapter with the given TIR function or module. + + Args: + params: List of tensor types for inputs/outputs + result_idx: Indices of output tensors + target: Target platform (e.g., 'cuda') + func_or_mod: TIR function or module to be compiled + verbose: Enable verbose logging + """ + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + self.ir_module = func_or_mod + + self.target = Target.canon_target(determine_target(target)) + + self.host_mod = host_mod + self.device_mod = device_mod + self.rt_mod = rt_mod + self.verbose = verbose + self.pass_configs = pass_configs + self.compile_flags = compile_flags + self.dynamic_symbolic_map = self._process_dynamic_symbolic() + + self._post_init() + + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: + """Extract information about dynamic shapes from the TIR function. 
+ + Maps symbolic variables to their corresponding (id, buffer_index, dimension) + for runtime shape resolution. + id represents shape or stride, 0 represents shape, 1 represents stride + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map = {} + for i, param in enumerate(params): + if isinstance(param, tir.Var) and (param not in dynamic_symbolic_map): + dynamic_symbolic_map[param] = (2, i, -1) + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): + dynamic_symbolic_map[shape] = (0, i, j) + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, stride in enumerate(buffer.strides): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): + dynamic_symbolic_map[stride] = (1, i, j) + return dynamic_symbolic_map + + def _convert_torch_func(self) -> Callable[..., Any]: + # Capture thunks that reflect Torch's current stream and device. + # These are evaluated at call time to align TVM execution with the + # caller's active PyTorch stream/device. + # current_stream_functor = self.get_current_stream_functor() + current_device_functor = self.get_current_device_functor() + + # Convert TVM types to native Python types during initialization + # Convert tvm.DataType to torch.dtype for tensor creation + param_dtypes = [param.torch_dtype() for param in self.params] + # Convert TVM shape arrays to native Python lists + param_shapes = [] + + for param in self.params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + native_shape.append(dim) # Keep tir.Var for dynamic dimensions + else: + native_shape.append(dim) + param_shapes.append(native_shape) + + if self.executable is None: + self.executable = runtime.Executable(self.rt_mod) + + dynamic_symbolic_map = self._process_dynamic_symbolic() + executable = self.executable + + # Prepare helpers for friendly dtype error messages + prim_func = self.prim_func + buffer_map = prim_func.buffer_map + params = prim_func.params + # Expected dtype string per parameter index (for buffers only) + expected_dtype_strs: list[str | None] = [] + # Track whether each param is a buffer (has dtype) vs scalar + is_buffer_param: list[bool] = [] + for p in params: + if p in buffer_map: + expected_dtype_strs.append(str(buffer_map[p].dtype)) + is_buffer_param.append(True) + else: + expected_dtype_strs.append(None) + is_buffer_param.append(False) + + # Map torch dtype to TVM-style dtype string + def torch_dtype_to_tvm_str(dtype: torch.dtype) -> str: + try: + import torch as _torch + except Exception: # pragma: no cover + # Fallback, though torch should always be available here + return str(dtype) + fp8_e4m3fn = getattr(_torch, "float8_e4m3fn", None) + fp8_e4m3fnuz = getattr(_torch, "float8_e4m3fnuz", None) + fp8_e5m2 = getattr(_torch, "float8_e5m2", None) + fp8_e5m2fnuz = getattr(_torch, "float8_e5m2fnuz", None) + if fp8_e4m3fn is not None and dtype == fp8_e4m3fn: + return "float8_e4m3" + if fp8_e4m3fnuz is not None and dtype == fp8_e4m3fnuz: + return "float8_e4m3fnuz" + if fp8_e5m2 is not None and dtype == fp8_e5m2: + return "float8_e5m2" + if fp8_e5m2fnuz is not None and dtype == fp8_e5m2fnuz: + return "float8_e5m2" + # Strip torch. 
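As a concrete (hypothetical) illustration of the `(kind, param_index, dim)` triples built by `_process_dynamic_symbolic` above, where kind 0 = shape, 1 = stride, 2 = scalar parameter; keys are shown as strings for brevity, the real map keys are `tir.Var` objects:

```python
# Hypothetical kernel signature: main(n: int32, A: Buffer[(m, k), "float16"]).
example_map = {
    "n": (2, 0, -1),  # scalar param: taken directly from inputs[0]
    "m": (0, 1, 0),   # shape var:  tensor_list[1].shape[0]
    "k": (0, 1, 1),   # shape var:  tensor_list[1].shape[1]
}
```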
prefix for readability + s = str(dtype) + return s[6:] if s.startswith("torch.") else s + + def func(*inputs: torch.Tensor | Any): + # Validate input count strictly + expected_inputs = len(self.params) - len(self.result_idx) + if len(inputs) != expected_inputs: + raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.") + + # Resolve the device used for outputs. Prefer the first tensor input's device + # if available, otherwise use PyTorch's current device. + out_device: torch.device | None = None + + # Stitch the full positional argument list expected by the TVM executable + ins_idx: int = 0 + tensor_list: list[torch.Tensor] = [] + + # Prepare input and output tensors + for i in range(len(self.params)): + if i in self.result_idx: + dtype = param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in param_shapes[i]: + if isinstance(s, tir.Var): + for key in dynamic_symbolic_map: + if str(s) == str(key): + ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key] + if ref_id == 2: + shape.append(inputs[ref_tensor_idx]) + elif ref_id == 0: + shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) + elif ref_id == 1: + shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) + else: # Already converted to Python int during initialization + shape.append(s) + + if out_device is None: + out_device = current_device_functor() + + if len(shape) == 0: + param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}" + raise ValueError( + f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. " + f"Expected shape: {shape}" + ) + tensor = torch.empty(*shape, dtype=dtype, device=out_device) + else: + tensor = inputs[ins_idx] + ins_idx += 1 + tensor_list.append(tensor) + + executable(*tensor_list) + + # Return outputs in the requested form + if len(self.result_idx) == 1: + return tensor_list[self.result_idx[0]] + return [tensor_list[i] for i in self.result_idx] + + return func + + @classmethod + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source + adapter.pass_configs = pass_configs + + if isinstance(func_or_mod, tir.PrimFunc): + adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + target = determine_target(target, return_object=True) + adapter.target = Target.canon_target(determine_target(target)) + + adapter.verbose = verbose + adapter.executable = runtime.load_module(kernel_lib_path) + adapter._post_init() + return adapter + + def get_host_source(self): + """Returns the source code of the host module.""" + if self.host_kernel_source is not None: + return self.host_kernel_source + return self.rt_mod.inspect_source() + + def get_device_source(self): + """Returns the source code of the device module.""" + if self.device_kernel_source is not None: + return self.device_kernel_source + return 
self.rt_mod.imports[0].inspect_source() + + def get_kernel_source(self, kernel_only: bool = False): + """Returns the source code of the compiled kernel.""" + if kernel_only: + return self.get_device_source() + else: + return self.get_device_source() + "\n\n" + self.get_host_source() + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) diff --git a/tilelang/original/tilelang/jit/adapter/utils.py b/tilelang/original/tilelang/jit/adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d43adf840aafec7bc2f0b2ae2e01a3aa84f272a1 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/utils.py @@ -0,0 +1,494 @@ +from __future__ import annotations + +import re +from typing import Literal, Callable, Any +from tilelang import tvm as tvm +from tvm import IRModule, tir +from tvm.target import Target +from tilelang.engine.lower import ( + get_device_call, + get_host_call, + determine_target, + canon_target_host, + is_cpu_device_backend, +) +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) + + +def match_global_kernel(source: str, annotation: str = "__global__") -> int: + pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" + for line in source.split("\n"): + if annotation in line: + matched = re.findall(pattern, line) + if len(matched) >= 1: + return source.index(matched[0]) + raise ValueError("No global kernel found in the source code") + + +def match_declare_kernel(source: str, annotation: str = "__global__") -> int: + pattern = r"__global__\s+void\s+(?:__launch_bounds__\(\d+\)\s+)?\w+" + for line in source.split("\n"): + if annotation in line: + matched = re.findall(pattern, line) + if len(matched) >= 1: + return source.index(matched[0] + "(") + raise ValueError("No global kernel found in the source code") + + +def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int: + # Match decorator followed by function definition across lines + # \s+ allows any whitespace including newlines between decorator and def + pattern = r"@cute\.kernel\s+def\s+(\w+)" + matched = re.search(pattern, source, re.MULTILINE) + if matched: + # Find the position of the opening parenthesis after the function name + # matched.start(1) gives position of function name + func_name_pos = matched.start(1) + # Find the '(' after function name + paren_pos = source.find("(", func_name_pos) + if paren_pos != -1: + return paren_pos + raise ValueError("No global kernel found in the source code") + + +def extract_python_func_declaration(source: str, func_name: str) -> str: + """Extract the full Python function declaration from decorator to colon. + + Args: + source: Source code containing the function + func_name: Name of the function to extract (can include '(' suffix) + + Returns: + The function declaration from 'def' to ':', including parameters + + Example: + For code: + @cute.kernel + def kernel(arg1: cute.Tensor, arg2: int): + ... 
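A quick usage sketch for `match_declare_kernel` above, with a made-up one-line CUDA source string:

```python
src = 'extern "C" __global__ void __launch_bounds__(128) main_kernel(float* __restrict__ A);'
idx = match_declare_kernel(src)
# idx points at the start of the kernel signature:
assert src[idx:].startswith("__global__ void __launch_bounds__(128) main_kernel(")
```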
+ Returns: "def kernel(arg1: cute.Tensor, arg2: int)" + """ + # Remove '(' suffix if present + if func_name.endswith("("): + func_name = func_name[:-1] + + # Match from def to the closing ) followed by : + # This handles multi-line function signatures + pattern = rf"def\s+{re.escape(func_name)}\s*\([^)]*\)" + matched = re.search(pattern, source, re.DOTALL) + if matched: + return matched.group(0) + + raise ValueError(f"No function declaration found for {func_name}") + + +def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int: + pattern = r"int32_t\s+\w+" + for line in source.split("\n"): + if annotation in line: + matched = re.findall(pattern, line) + if len(matched) >= 1: + return source.index(matched[0] + "(") + raise ValueError("No global kernel found in the source code") + + +def is_cuda_target(target: Target) -> bool: + return target.kind.name == "cuda" + + +def is_hip_target(target: Target) -> bool: + return target.kind.name == "hip" + + +def is_cpu_target(target: Target) -> bool: + return target.kind.name in ["c"] + + +def is_metal_target(target: Target) -> bool: + return target.kind.name == "metal" + + +def is_cutedsl_target(target: Target) -> bool: + return target.kind.name == "cuda" and "cutedsl" in target.keys + + +def get_annotated_mod( + func_or_mod: tir.PrimFunc | tvm.IRModule, + target: str | Target = "auto", + target_host: str | Target | None = None, + model_type: Literal["device", "host", "all"] = "all", +) -> IRModule | tuple[IRModule, IRModule]: + # Validate model_type early + if model_type not in {"device", "host", "all"}: + raise ValueError(f"Invalid model type: {model_type}") + + # Convert PrimFunc to IRModule if needed + mod = func_or_mod + if isinstance(func_or_mod, tir.PrimFunc): + mod = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + + # Handle target and target_host + if isinstance(target, str): + target = determine_target(target) + target_host = tvm.target.Target.canon_target(canon_target_host(target, target_host)) + target = tvm.target.Target(target, target_host) + + _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) + _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + + # Apply transformations + mod = LowerAndLegalize(mod, target) + mod = OptimizeForTarget(mod, target) + + # Define dispatch dictionary for different model types + dispatch = { + "device": lambda m: tir.transform.Filter(_is_device_call)(m), + "host": lambda m: tir.transform.Filter(_is_host_call)(m), + "all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)), + } + + return dispatch[model_type](mod) + + +def pythonic_expr( + expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False, floor_div_op: str = "/" +) -> str: + """ + Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. + + Args: + expr: The TVM PrimExpr to convert. + dtype_map: A dictionary mapping data types to their string representations. + ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast. + floor_div_op: Operator to use for tvm.tir.FloorDiv. Default '/' preserves prior + behavior (suitable for generating C/C++ expressions). For generating + Python code where integer division is required (e.g. grid/block), + pass '//' explicitly. + Returns: + A string representation of the expression. + """ + if not isinstance(expr, tvm.tir.PrimExpr): + return str(expr) + + # 1. 
Define operator precedence (higher value means higher precedence) + # Based on Python's operator precedence + PRECEDENCE = { + tvm.tir.Call: 20, # Includes min, max + tvm.tir.Cast: 20, # Treated like a function call + tvm.tir.Mul: 13, + tvm.tir.FloorDiv: 13, + tvm.tir.Div: 13, # For tvm.tir.Div if it appears + tvm.tir.FloorMod: 13, + tvm.tir.Add: 12, + tvm.tir.Sub: 12, + tvm.tir.LT: 10, + tvm.tir.LE: 10, + tvm.tir.GT: 10, + tvm.tir.GE: 10, + tvm.tir.EQ: 10, + tvm.tir.NE: 10, + tvm.tir.And: 5, + tvm.tir.Or: 4, + # Atoms (Var, IntImm) have the highest precedence implicitly + } + # By default, atomic expressions (variables, constants) have the highest precedence + ATOMIC_PRECEDENCE = 100 + + node_to_result_map = {} # Stores (string, precedence) for each node + + def _visitor(node): + # 2. Visitor returns (str, precedence) tuple + if node in node_to_result_map: + return + + if isinstance(node, tvm.tir.Var): + s, p = node.name, ATOMIC_PRECEDENCE + elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)): + s, p = str(node.value), ATOMIC_PRECEDENCE + elif isinstance(node, tvm.tir.Cast): + # C-style cast has high precedence + value_str, _ = node_to_result_map[node.value] + if ignore_cast: + s = value_str + else: + type_str = node.dtype if dtype_map is None else dtype_map[node.dtype] + s = f"({type_str}){value_str}" + p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) + elif isinstance( + node, + ( + tvm.tir.Mul, + tvm.tir.FloorDiv, + tvm.tir.Add, + tvm.tir.Sub, + tvm.tir.FloorMod, + tvm.tir.LT, + tvm.tir.LE, + tvm.tir.GT, + tvm.tir.GE, + tvm.tir.EQ, + tvm.tir.NE, + tvm.tir.And, + tvm.tir.Or, + ), + ): + op_map = { + tvm.tir.Mul: "*", + tvm.tir.FloorDiv: floor_div_op, + tvm.tir.Add: "+", + tvm.tir.Sub: "-", + tvm.tir.FloorMod: "%", + tvm.tir.LT: "<", + tvm.tir.LE: "<=", + tvm.tir.GT: ">", + tvm.tir.GE: ">=", + tvm.tir.EQ: "==", + tvm.tir.NE: "!=", + tvm.tir.And: "and", + tvm.tir.Or: "or", + } + op_str = f" {op_map[type(node)]} " + my_precedence = PRECEDENCE[type(node)] + + a_str, a_precedence = node_to_result_map[node.a] + b_str, b_precedence = node_to_result_map[node.b] + + # 3. Add parentheses intelligently + # Add parentheses if the left operand's precedence is lower than the current operator + if a_precedence < my_precedence: + a_str = f"({a_str})" + # Add parentheses if the right operand's precedence is lower than or equal to the current operator + # 'Equal' is to handle non-associative operations, e.g., a - (b - c) + if b_precedence <= my_precedence: + b_str = f"({b_str})" + + s = f"{a_str}{op_str}{b_str}" + p = my_precedence + elif isinstance(node, (tvm.tir.Min, tvm.tir.Max)): + op_name = "min" if isinstance(node, tvm.tir.Min) else "max" + a_str, _ = node_to_result_map[node.a] + b_str, _ = node_to_result_map[node.b] + s = f"{op_name}({a_str}, {b_str})" + # Function calls have high precedence + p = PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + else: + # Fallback for unhandled expression types + s, p = str(node), 0 + + node_to_result_map[node] = (s, p) + + # Perform post-order traversal + tvm.tir.stmt_functor.post_order_visit(expr, _visitor) + + return next(iter(node_to_result_map[expr]), "") + + +def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool: + """ + Check if a parameter name corresponds to a TMA descriptor. + + Args: + name: The parameter name to check. + matches: List of all matched parameter names. + i: Index of the current match. + desc_name_map: Optional mapping to store descriptor name relationships. 
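A short usage sketch of `pythonic_expr` (assuming a standard tvm install); it shows the precedence-aware parenthesisation and the `floor_div_op` switch:

```python
import tvm

a = tvm.tir.Var("a", "int32")
b = tvm.tir.Var("b", "int32")
c = tvm.tir.Var("c", "int32")

print(pythonic_expr(a * (b + c)))                                # a * (b + c)
print(pythonic_expr(a * b + c))                                  # a * b + c
print(pythonic_expr(tvm.tir.floordiv(a, b)))                     # a / b  (C-style default)
print(pythonic_expr(tvm.tir.floordiv(a, b), floor_div_op="//"))  # a // b
```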
+ + Returns: + True if the parameter is a TMA descriptor. + """ + match = matches[i] + if not (match == name + "_desc" or match.startswith(name + "_desc_")): + return False + desc_decls = [] + if desc_name_map is not None: + desc_name_map[match] = name + if i > 0: + desc_decls.append(matches[i - 1]) + if i < len(matches) - 1: + desc_decls.append(matches[i + 1]) + return any([decl == "CUtensorMap" for decl in desc_decls]) + + +def parse_function_call_args( + declaration: str, + function_args: list[dict[str, str]], + function_params: list[Any], + desc_name_map: dict[str, str] | None = None, + desc_name_var_map: dict[str, tvm.tir.Var] | None = None, + transform_arg: Callable[[str, str], Any] | None = None, +) -> list[Any]: + """ + Parse function call arguments from a kernel declaration. + + Args: + declaration: The kernel function declaration string. + function_args: List of function argument specifications. + function_params: List of function parameters from TVM IR. + desc_name_map: Optional mapping for descriptor names. + desc_name_var_map: Optional mapping from descriptor names to TVM variables. + transform_arg: Optional function to transform each argument (name, type) -> result. + + Returns: + List of parsed call arguments. + """ + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, declaration) + call_args = [] + + for i, match in enumerate(matches): + for arg in function_args: + if arg["name"] == match: + if transform_arg is not None: + call_args.append(transform_arg(match, arg["type"])) + else: + call_args.append(match) + elif maybe_desc_name(arg["name"], matches, i, desc_name_map): + if transform_arg is not None: + call_args.append(transform_arg(match, "None")) + else: + call_args.append(match) + if desc_name_var_map is not None and function_params is not None: + assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}" + desc_name_var_map[match] = function_params[len(call_args) - 1] + + return call_args + + +class TMADescriptorParams: + """Parsed TMA descriptor parameters.""" + + def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False): + self.handle_name = handle_name + self.dtype = dtype + self.tensor_rank = tensor_rank + self.global_address = global_address + self.is_img2col = is_img2col + + # Common fields + self.global_dim: list[str] = [] + self.global_stride: list[str] = [] + self.element_strides: list[str] = [] + self.interleave: str = "" + self.swizzle: str = "" + self.l2_promotion: str = "" + self.oob_fill: str = "" + + # Tiled-specific fields + self.box_dim: list[str] = [] + + # Im2col-specific fields + self.lower_corner: list[str] = [] + self.upper_corner: list[str] = [] + self.smem_box_channel: str = "" + self.smem_box_pixel: str = "" + + +def parse_tma_descriptor_args( + tma_descriptor_args: dict[tvm.tir.Var, list[Any]], + desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var], + pythonic_expr_func: Callable[[Any], str], +) -> list[TMADescriptorParams]: + """ + Parse TMA descriptor arguments into structured parameters. + + Args: + tma_descriptor_args: Dictionary mapping TMA descriptor variables to their arguments. + desc_name_map: Mapping from descriptor handles to parameter names. + desc_name_var_map: Mapping from descriptor handles to TVM variables. + pythonic_expr_func: Function to convert TVM expressions to strings. + + Returns: + List of parsed TMA descriptor parameters. 
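To illustrate `parse_function_call_args` defined above, here is a sketch with an invented declaration; the descriptor parameter is picked up because it is named `<buffer>_desc` and sits next to a `CUtensorMap` token in the declaration:

```python
decl = "void main_kernel(half_t* __restrict__ A, int n, const CUtensorMap A_desc)"
fn_args = [
    {"name": "A", "type": "ctypes.c_void_p"},
    {"name": "n", "type": "ctypes.c_int32"},
]
names = {}
call_args = parse_function_call_args(decl, fn_args, None, desc_name_map=names)
# call_args == ["A", "n", "A_desc"] and names == {"A_desc": "A"}
```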
+ """ + if not tma_descriptor_args: + return [] + + results = [] + + for handle_name, _ in desc_name_map.items(): + assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map" + desc_var = desc_name_var_map[handle_name] + + assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" + args = tma_descriptor_args[desc_var] + + # Skip __tvm_tensormap_create_tiled and second element (like CUDA version) + if len(args) < 3: + raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3") + + tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args + + is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col" + + # Convert basic fields + dtype = pythonic_expr_func(dtype) + tensor_rank = int(pythonic_expr_func(tensor_rank)) + + # Validate tensor_rank + if not isinstance(tensor_rank, int) or tensor_rank <= 0: + raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer") + + params = TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col) + + if not is_img2col: + # Tiled mode + expected_args_len = 4 * tensor_rank + 4 + if len(remaining_args) < expected_args_len: + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) + + # Extract dimensions and strides + params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]] + + # Extract remaining parameters + try: + interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4] + params.interleave = pythonic_expr_func(interleave) + params.swizzle = pythonic_expr_func(swizzle) + params.l2_promotion = pythonic_expr_func(l2_promotion) + params.oob_fill = pythonic_expr_func(oob_fill) + except ValueError as e: + raise ValueError("Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e + else: + # Im2col mode + expected_args_len = 5 * tensor_rank + 2 + if len(remaining_args) < expected_args_len: + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) + + # Extract dimensions and strides + params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]] + params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]] + + # Extract remaining parameters + try: + smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[ + 5 * tensor_rank - 4 : 5 * tensor_rank + 2 + ] + params.smem_box_pixel = pythonic_expr_func(smem_box_pixel) + params.smem_box_channel = pythonic_expr_func(smem_box_channel) + params.interleave = pythonic_expr_func(interleave) + params.swizzle = 
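The tiled branch above slices `remaining_args` at fixed offsets; for a rank-2 descriptor the layout works out as below (the indices are derived from the code, not from any particular kernel):

```python
# remaining_args layout for a rank-2 tiled TMA descriptor (4 * rank + 4 entries):
rank = 2
layout = {
    "global_dim":      slice(0 * rank, 1 * rank),  # [0:2]
    "global_stride":   slice(1 * rank, 2 * rank),  # [2:4]
    "box_dim":         slice(2 * rank, 3 * rank),  # [4:6]
    "element_strides": slice(3 * rank, 4 * rank),  # [6:8]
    # followed by interleave, swizzle, l2_promotion, oob_fill at [8:12]
}
```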
pythonic_expr_func(swizzle) + params.l2_promotion = pythonic_expr_func(l2_promotion) + params.oob_fill = pythonic_expr_func(oob_fill) + except ValueError as e: + raise ValueError( + "Failed to unpack the final 6 TMA parameters " + "(smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)" + ) from e + + results.append(params) + + return results diff --git a/tilelang/original/tilelang/jit/adapter/wrapper.py b/tilelang/original/tilelang/jit/adapter/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d83d0ccc0689375cecf0ded98f20f0e895fe35a3 --- /dev/null +++ b/tilelang/original/tilelang/jit/adapter/wrapper.py @@ -0,0 +1,912 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from tilelang import tvm as tvm +from typing import Any +from tvm import IRModule +from tvm.target import Target + +from .utils import ( + is_metal_target, + is_cutedsl_target, + match_declare_kernel, + match_declare_kernel_cpu, + is_cuda_target, + is_hip_target, + is_cpu_target, + get_annotated_mod, + pythonic_expr, + parse_function_call_args, + parse_tma_descriptor_args, +) +import re +import logging +import textwrap +from tvm.tir.stmt_functor import post_order_visit + +PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """ + cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1}); + if (result_{0} != cudaSuccess) {{ + snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0})); + return -1; + }} +""" + +PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP = """ + if ({1} > 65536) {{ + snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size for {0} to %d", {1}); + return -1; + }} + return 0; +""" + +PREDEF_INIT_FUNC = """ +#define ERROR_BUF_SIZE 1024 +static char error_buf[ERROR_BUF_SIZE]; + +extern "C" const char* get_last_error() {{ + return error_buf; +}} + +extern "C" int init() {{ + error_buf[0] = '\\0'; + {0} + return 0; +}} +""" + +PREDEF_HOST_FUNC = """ +extern "C" int call({}) {{ +{} +\treturn 0; +}} +""" + +L2_PERSISTENT_MAP_CREATE_HANDLE = """ +\tcudaStreamAttrValue stream_attribute; +\tsize_t init_persisting_l2_cache_size; +\tcudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize); +""" + +L2_PERSISTENT_MAP_INIT_FUNC = """ +\tstream_attribute.accessPolicyWindow.hitRatio = {1}; +\tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting; +\tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming; +\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2}); +\tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0}); +\tstream_attribute.accessPolicyWindow.num_bytes = {2}; +\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute); +""" + +L2_PERSISTENT_MAP_RESET_HANDLE = """ +\tstream_attribute.accessPolicyWindow.num_bytes = 0; +\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute); +\tcudaCtxResetPersistingL2Cache(); +\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size); +""" + +TMA_DESC_INIT_FUNC = """ +\tCUtensorMap {0}; +\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1}; +\tcuuint32_t {0}_tensorRank= {2}; +\tvoid *{0}_globalAddress= {3}; +\tcuuint64_t {0}_globalDim[{2}]= {{{4}}}; +\tcuuint64_t {0}_globalStride[{2}]= {{{5}}}; +\tcuuint32_t {0}_boxDim[{2}]= {{{6}}}; +\tcuuint32_t {0}_elementStrides[{2}]= {{{7}}}; 
+\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){8}; +\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9}; +\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10}; +\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11}; + +\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill); + +\tif ({0}_result != CUDA_SUCCESS) {{ +\t\tstd::stringstream ss; +\t\tss << "Error: Failed to initialize the TMA descriptor {0}"; +\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str()); +\t\treturn -1; +\t}} +""" + +TMA_IM2COL_DESC_INIT_FUNC = """ +\tCUtensorMap {0}; +\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1}; +\tcuuint32_t {0}_tensorRank= {2}; +\tvoid *{0}_globalAddress= {3}; +\tcuuint64_t {0}_globalDim[{2}]= {{{4}}}; +\tcuuint64_t {0}_globalStride[{2}]= {{{5}}}; +\tcuuint32_t {0}_elementStrides[{2}]= {{{6}}}; +\tint {0}_lowerCorner[{2} - 2]= {{{7}}}; +\tint {0}_upperCorner[{2} - 2]= {{{8}}}; +\tcuuint32_t {0}_channelsPerPixel= {9}; +\tcuuint32_t {0}_pixelsPerColumn= {10}; +\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11}; +\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12}; +\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13}; +\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14}; + +\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( + &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, + {0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill); + +\tif ({0}_result != CUDA_SUCCESS) {{ +\t\tstd::stringstream ss; +\t\tss << "Error: Failed to initialize the TMA descriptor {0}"; +\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str()); +\t\treturn -1; +\t}} +""" + + +class BaseWrapper(ABC): + @abstractmethod + def wrap(self, *args, **kwargs): + raise NotImplementedError + + +logger = logging.getLogger(__name__) + + +class TLCUDASourceWrapper: + _TYPE_MAP = { + "float32": "float", + "float16": "half_t", + "bfloat16": "bfloat16_t", + "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", + "float8_e5m2": "fp8_e5_t", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "uchar": "uint8_t", + } + + backend = "tl" + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + self.mod = scheduled_ir_module + self.target = target + self.source = source + self.pass_configs = pass_configs + self.device_mod = device_mod + self.host_mod = host_mod + self.function_names: str | None = None + self.dynamic_smem_buf: int | None = None + self.block_info: list[int] | dict = [1, 1, 1] + self.grid_info: list[int] | dict = [1, 1, 1] + self.tma_descriptor_args: dict | None = None + self.l2_persistent_map: dict[str, dict] | None = {} + self.parse_source_information() + self.srcpath: str | None = None + 
self.libpath: str | None = None + self.lib_code: str | None = self.update_lib_code(source) + + def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + # This wrapper generates C/CUDA source. C/C++ integer division uses '/', + # and '//' is not a valid operator in C/C++. + return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="/") + + def _lookup_type(self, dtype: str | Any) -> str: + key = dtype if isinstance(dtype, str) else str(dtype) + result = self._TYPE_MAP.get(key) + assert result is not None, f"Unsupported dtype {dtype}" + return result + + def is_tma_descriptor_arg(self, arg_name: str) -> bool: + return arg_name in self.prim_func.buffer_map + + def create_dispatch_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + function_args = [] + + # Collect function arguments based on primary function's parameters and buffer mappings + # QA(@lei): Why not use device_mod.params? + # device func lack buffer map (to convert buffer handle to buffer) + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append( + { + "name": buffer.data.name, + "type": self._lookup_type(buffer.dtype) + "* __restrict__", + } + ) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) + else: + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") + # Add dynamic symbols as integer arguments + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: + if dyn_sym not in [arg["name"] for arg in function_args]: + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) + + function_args.append(self.get_stream_type()) + + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + has_l2_persistent_map = False + for function_name, _ in function_informations.items(): + if function_name in self.l2_persistent_map: + has_l2_persistent_map = True + break + + kernel_launch_code = """""" + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} + for function_name, function_info in function_informations.items(): + block_info = function_info["block_info"] + grid_info = function_info["grid_info"] + dynamic_smem_buf = function_info["dynamic_smem_buf"] + function_params = function_info["function_params"] + + # Find the location of the global kernel function in the code + index = match_declare_kernel(code, function_name + "(") + + # Analyze the function declaration to prepare for argument extraction + declaration = code[index:].split(";")[0] + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + + block_str = ( + f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" + ) + grid_str = ( + f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" + ) + smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf + init_l2_persistent_map = self.generate_l2_persistent_map(function_name) + kernel_launch_code += init_l2_persistent_map + + if self.use_cooperative_groups[function_name]: + args_list = parse_function_call_args(declaration, 
function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) + args_array = [f"(void*)&{arg}" for arg in args_list] + call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n" + kernel_launch_code += call_args + # Using cudaLaunchCooperativeKernel to launch the kernel + kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( + function_name, grid_str, block_str, function_name + "_args", smem_str + ) + else: + args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) + call_args = ", ".join(args_list) + kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n" + kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n' + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE + + init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) + kernel_launch_code = init_tma_descriptor_args + kernel_launch_code + + # Wrap the kernel dispatch logic in an external C function + host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code) + return host_func + + def generate_l2_persistent_map(self, function_name: str) -> str: + if function_name not in self.l2_persistent_map: + return "" + init_l2_persistent_map = "" + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): + # get persisting_l2_cache_max_size + from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() + try: + num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) + except Exception: + # as size_in_bytes maybe a symbolic expression + num_bytes = persisting_l2_cache_max_size + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + + return init_l2_persistent_map + + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + tma_descriptor_init = "" + if self.tma_descriptor_args is None: + return tma_descriptor_init + + # Parse TMA descriptor arguments using the common utility + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) + + # Generate C++ code from parsed parameters + for params in parsed_params: + if not params.is_img2col: + tma_descriptor_init += TMA_DESC_INIT_FUNC.format( + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.box_dim), + ",".join(params.element_strides), + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) + else: + tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC.format( + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.element_strides), + ",".join(params.lower_corner), + ",".join(params.upper_corner), + params.smem_box_channel, + 
params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) + + return tma_descriptor_init + + def parse_source_information(self): + if self.device_mod is None or self.host_mod is None: + with tvm.transform.PassContext(opt_level=3, config=self.pass_configs): + device_mod, host_mod = get_annotated_mod(self.mod, self.target) + self.device_mod = device_mod + self.host_mod = host_mod + assert len(self.device_mod.functions) >= 1, "Device module should have at least one function." + assert len(self.host_mod.functions) == 1, "Only support one function in host module." + + block_info_map = {} + grid_info_map = {} + dynamic_smem_buf_map = {} + function_names = [] + use_cooperative_groups_map = {} + for g_var, func in self.device_mod.functions.items(): + # Default block and grid configurations + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + function_name = g_var.name_hint + attrs = func.attrs + dynamic_smem_buf = None + use_cooperative_groups = False + if "use_cooperative_groups" in attrs: + use_cooperative_groups = attrs["use_cooperative_groups"] + if "dyn_shared_memory_buf" in attrs: + dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + # Extract block and grid sizes from thread extents + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + grid_info["xyz".index(tag[-1])] = extent + # Map the extracted configurations to each function + block_info_map[function_name] = block_info + grid_info_map[function_name] = grid_info + dynamic_smem_buf_map[function_name] = dynamic_smem_buf + use_cooperative_groups_map[function_name] = use_cooperative_groups + function_names.append(function_name) + + # Store the mappings for use in code generation + self.block_info = block_info_map + self.grid_info = grid_info_map + self.dynamic_smem_buf = dynamic_smem_buf_map + self.use_cooperative_groups = use_cooperative_groups_map + + function_names_index = {} + for _, func in self.host_mod.functions.items(): + if "tma_descriptor_args" in func.attrs: + self.tma_descriptor_args = func.attrs["tma_descriptor_args"] + if "l2_persistent_map" in func.attrs: + self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"] + + host_code = str(func) + for function_name in function_names: + index = host_code.index(f'T.call_packed("{function_name}"') + function_names_index[function_name] = index + # sort function_names + function_names = sorted(function_names, key=lambda x: function_names_index[x]) + self.function_names = function_names + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set: dict[str, str] = {} + + def unique_push_back(name: str, dtype: str): + if name not in dynamic_symbolic_set: + dynamic_symbolic_set[name] = dtype + else: + assert dtype == dynamic_symbolic_set[name] + + for param in prim_func.params: + if param in prim_func.buffer_map: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + unique_push_back(dim.name, str(dim.dtype)) + + # Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape. 
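The `thread_extent` handling in `parse_source_information` above maps tag suffixes onto axis slots; a self-contained sketch with invented extents:

```python
# Invented attrs: thread_extent as it might appear on a device function.
thread_extent = {"blockIdx.x": 256, "blockIdx.y": 2, "threadIdx.x": 128}

block_info, grid_info = [1, 1, 1], [1, 1, 1]
for tag, extent in thread_extent.items():
    if "threadIdx" in tag:
        block_info["xyz".index(tag[-1])] = extent
    elif "blockIdx" in tag:
        grid_info["xyz".index(tag[-1])] = extent

# block_info == [128, 1, 1]; grid_info == [256, 2, 1]
```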
+ for param in prim_func.params: + if param in prim_func.buffer_map: + buffer = prim_func.buffer_map[param] + for stride in buffer.strides: + if isinstance(stride, tvm.tir.Var): + unique_push_back(stride.name, str(stride.dtype)) + + return list(dynamic_symbolic_set.items()) + + def get_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = PREDEF_INIT_FUNC.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Get the function names + function_names = self.function_names + # Get the CUDA initialization function + init_func = self.get_init_func() + + # Organize function information for code generation + function_informations = {} + for function_name in function_names: + # Do not update function with dispatch host function + if (function_name not in self.block_info) or (function_name not in self.grid_info): + continue + assert function_name in self.device_mod, f"Function {function_name} not found in device module" + device_func = self.device_mod[function_name] + kernel_params_cnt = len(device_func.params) + function_params: list[str] = None + + def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): + nonlocal function_params + if isinstance(node, tvm.tir.Call): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + return + args = node.args + if not args or args[0] != fn: + return + if len(args) < 1 + param_cnt: + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] + + post_order_visit(self.host_func.body, visitor) + assert function_params is not None, "function_params should not be None" + + function_informations[function_name] = { + "function_name": function_name, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + "function_params": function_params, + } + + # Create the host function wrapper for the CUDA kernel + host_func = self.create_dispatch_func(code, function_informations) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + def get_stream_type(self) -> dict[str, str]: + return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"} + + @property + def prim_func(self): + if len(self.mod.get_global_vars()) == 1: + return self.mod[self.mod.get_global_vars()[0]] + elif "main" in self.mod: + return self.mod["main"] + else: + for _, function in self.mod.functions_items(): + attr = function.attrs + if "tir.is_global_func" in attr and attr["tir.is_global_func"]: + return function + raise ValueError("Cannot find primary function in the module.") + + @property + def device_func(self): + if len(self.device_mod.get_global_vars()) == 1: + return self.device_mod[self.device_mod.get_global_vars()[0]] + elif "main" in self.device_mod: + return self.device_mod["main"] + 
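+        # Hedged illustration of the packed-call matching in update_lib_code above: a host-side
+        # call of the form T.call_packed("main_kernel", A_handle, B_handle, n) yields
+        # function_params == (A_handle, B_handle, n) when the device function declares three
+        # parameters (all names hypothetical).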
else: + for _, function in self.device_mod.functions.items(): + attr = function.attrs + if "tir.is_global_func" in attr and attr["tir.is_global_func"]: + return function + raise ValueError("Cannot find primary function in the module.") + + @property + def host_func(self): + if len(self.host_mod.get_global_vars()) == 1: + return self.host_mod[self.host_mod.get_global_vars()[0]] + elif "main" in self.host_mod: + return self.host_mod["main"] + else: + for _, function in self.host_mod.functions.items(): + attr = function.attrs + if "tir.is_global_func" in attr and attr["tir.is_global_func"]: + return function + raise ValueError("Cannot find primary function in the module.") + + +class TLHIPSourceWrapper(TLCUDASourceWrapper): + """ + A wrapper class for the TileLang HIP backend. + """ + + _TYPE_MAP = { + "float32": "float", + "float16": "half_t", + "bfloat16": "bfloat16_t", + "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", + "float8_e5m2": "fp8_e5_t", + "float8_e4m3fnuz": "fp8_e4_t", + "e4m3fnuz_float8": "fp8_e4_t", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "uchar": "uint8_t", + } + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) + + def get_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = PREDEF_INIT_FUNC.format(call_str) + return init_funcs + + def get_stream_type(self) -> dict[str, str]: + return {"name": "stream=hipStreamDefault", "type": "hipStream_t"} + + +class TLCPUSourceWrapper: + _TYPE_MAP = { + "float32": "float", + "float16": "half", + "int32": "int32_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "int64": "int64_t", + "uint64": "uint64_t", + "float64": "double", + "bool": "bool", + "uchar": "uchar", + } + + # Use common init with error buffer and get_last_error for CPU backend as well + INIT_FUNC = PREDEF_INIT_FUNC.format("") + + CALL_PREFIX = textwrap.dedent(""" + #ifdef __cplusplus + extern "C" + #endif + int32_t call({}) {{ + return {}; + }} + """) + + backend = "tl" + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + self.mod = scheduled_ir_module + self.target = target + self.source = source + self.device_mod = device_mod + self.host_mod = host_mod + self.pass_configs = pass_configs + self.function_names: str | None = None + self.dynamic_smem_buf: int | None = None + self.parse_source_information() + self.srcpath: str | None = None + self.libpath: str | None = None + 
self.lib_code: str | None = self.update_lib_code(source) + + def _lookup_type(self, dtype: str | Any) -> str: + key = dtype if isinstance(dtype, str) else str(dtype) + result = self._TYPE_MAP.get(key) + assert result is not None, f"Unsupported dtype {dtype}" + return result + + def create_call_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + function_args = [] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append( + { + "name": buffer.name, + "type": self._lookup_type(buffer.dtype) + "*", + } + ) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) + else: + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") + # Add dynamic symbols as integer arguments + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + pattern = r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + _call_str = """""" + + for function_name, _ in function_informations.items(): + # Find the location of the global kernel function in the code + index = match_declare_kernel_cpu(code, function_name + "(") + + # Analyze the function declaration to prepare for argument extraction + declaration = code[index:].split(";")[0] + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + + call_args = ", ".join(func_call_args(declaration, function_args)) + _call_str += f"{function_name}({call_args})" + + # Wrap the kernel dispatch logic in an external C function + host_func = self.CALL_PREFIX.format(def_args, _call_str) + return host_func + + def parse_source_information(self): + with tvm.transform.PassContext(opt_level=3, config=self.pass_configs): + device_mod, host_mod = get_annotated_mod(self.mod, self.target) + assert len(device_mod.functions) >= 1, "Device module should have at least one function." + assert len(host_mod.functions) == 1, "Only support one function in host module." 
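+            # Hedged sketch of the wrapper emitted by create_call_func above for a kernel
+            # declared as main_kernel(float* A, float* B, int32_t m) (names hypothetical):
+            #   extern "C" int32_t call(float* A, float* B, int32_t m) {
+            #       return main_kernel(A, B, m);
+            #   }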
+ + function_names = [] + for g_var, _ in device_mod.functions.items(): + function_name = g_var.name_hint + function_names.append(function_name) + + self.function_names = function_names + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set: dict[str, str] = {} + for param in prim_func.params: + if param in prim_func.buffer_map: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set): + dynamic_symbolic_set[dim.name] = str(dim.dtype) + return list(dynamic_symbolic_set.items()) + + def get_cpu_init_func(self): + # Provide init() and get_last_error() for CPU backend + return self.INIT_FUNC + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Get the function names + function_names = self.function_names + # Get the CPU initialization function + init_func = self.get_cpu_init_func() + + # Organize function information for code generation + function_informations = {} + for function_name in function_names: + function_informations[function_name] = { + "function_name": function_name, + } + + # Create the call function wrapper for the CPU kernel + call_func = self.create_call_func(code, function_informations) + # Combine the source, initialization function, and call function to form the complete library code + lib_code = self.source + init_func + call_func + return lib_code + + @property + def prim_func(self): + if len(self.mod.get_global_vars()) == 1: + return self.mod[self.mod.get_global_vars()[0]] + elif "main" in self.mod: + return self.mod["main"] + else: + for _, function in self.mod.functions_items(): + attr = function.attrs + if "tir.is_global_func" in attr and attr["tir.is_global_func"]: + return function + raise ValueError("Cannot find primary function in the module.") + + +class TLMetalSourceWrapper: + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + self.mod = scheduled_ir_module + self.target = target + self.source = source + self.pass_configs = pass_configs + self.device_mod = device_mod + self.host_mod = host_mod + self.lib_code = self.update_lib_code(source) + + def update_lib_code(self, code: str): + self.lib_code = code + return self.lib_code + + +# TLCuTeDSLSourceWrapper has been moved to tilelang.jit.adapter.cutedsl.wrapper + + +class TLWrapper(BaseWrapper): + """ + A wrapper class for the TileLang backend. 
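+    Typical flow (a sketch, not the only entry point): assign the scheduled module and the
+    pass configs, optionally the host/device modules, then call wrap(c_source) to obtain the
+    combined library source; the concrete source wrapper (CUDA/HIP/CPU/Metal) is selected
+    from the target.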
+ """ + + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None + target: Target | None = None + lib: object | None = None + + def __init__(self, target: Target): + super().__init__() + self.scheduled_ir_module = None + self.pass_configs = None + self.target = target + self.lib = None + + def assign_optimized_module(self, scheduled_ir_module: IRModule): + self.scheduled_ir_module = scheduled_ir_module + + def assign_pass_configs(self, pass_configs: dict[str, Any]): + self.pass_configs = pass_configs + + def assign_host_module(self, host_mod: IRModule): + self.host_mod = host_mod + + def assign_device_module(self, device_mod: IRModule): + self.device_mod = device_mod + + # Get Scheduled Rt Module and return source to be compiled + def wrap(self, c_source: str): + assert self.scheduled_ir_module is not None, "Please assign optimized module first." + if is_cuda_target(self.target): + wrapper_class = TLCUDASourceWrapper + elif is_hip_target(self.target): + wrapper_class = TLHIPSourceWrapper + elif is_cpu_target(self.target): + wrapper_class = TLCPUSourceWrapper + elif is_metal_target(self.target): + wrapper_class = TLMetalSourceWrapper + else: + raise ValueError(f"Unsupported platform: {self.arch.platform}") + wrapper = wrapper_class( + scheduled_ir_module=self.scheduled_ir_module, + source=c_source, + target=self.target, + device_mod=self.device_mod, + host_mod=self.host_mod, + pass_configs=self.pass_configs, + ) + return wrapper.lib_code + + +class TLPyWrapper(TLWrapper): + def __init__(self, target: Target): + super().__init__(target) + + def wrap(self, py_source: str): + # assert self.scheduled_ir_module is not None, "Please assign optimized module first." + if is_cutedsl_target(self.target): + from tilelang.jit.adapter.cutedsl import TLCuTeDSLSourceWrapper + + wrapper_class = TLCuTeDSLSourceWrapper + elif is_cuda_target(self.target): + from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper + + wrapper_class = TLNVRTCSourceWrapper + else: + raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") + wrapper = wrapper_class( + scheduled_ir_module=self.scheduled_ir_module, + source=py_source, + target=self.target, + device_mod=self.device_mod, + host_mod=self.host_mod, + pass_configs=self.pass_configs, + ) + return { + "host_func": getattr(wrapper, "host_func", None), + "function_names": getattr(wrapper, "function_names", None), + "tma_cpp_init_code": getattr(wrapper, "tma_cpp_init_code", None), + "tma_lib_name": getattr(wrapper, "tma_lib_name", None), + "launcher_cpp_code": getattr(wrapper, "launcher_cpp_code", None), + "launcher_lib_name": getattr(wrapper, "launcher_lib_name", None), + } diff --git a/tilelang/original/tilelang/jit/env.py b/tilelang/original/tilelang/jit/env.py new file mode 100644 index 0000000000000000000000000000000000000000..6af7adc750f59fa3b338b17b2ba2dc69e13ef75d --- /dev/null +++ b/tilelang/original/tilelang/jit/env.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the flashinfer project +# (https://github.com/flashinfer-ai/flashinfer). +"""Library information. This is a standalone file that can be used to get various info. +Modified from flashinfer +""" + +import pathlib + +from tilelang.env import ( + CUTLASS_INCLUDE_DIR, # noqa: F401 + TILELANG_TEMPLATE_PATH, # noqa: F401 +) + + +def _get_workspace_dir_name() -> pathlib.Path: + try: + from tilelang.contrib import nvcc + from tilelang.utils.target import determine_target + + target = determine_target(return_object=True) + # create tmp source file for torch cpp extension + arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target)) + except Exception: + arch = "noarch" + # e.g.: $HOME/.cache/tilelang/75_80_89_90/ + return pathlib.Path.home() / ".cache" / "tilelang" / arch + + +TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name() +TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops" +TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated" diff --git a/tilelang/original/tilelang/jit/execution_backend.py b/tilelang/original/tilelang/jit/execution_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..db5e4a8b41d58695c7325aad7c35b0117f7f2a24 --- /dev/null +++ b/tilelang/original/tilelang/jit/execution_backend.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from tvm.target import Target +from tilelang.jit.adapter.utils import is_cutedsl_target +from tilelang.env import env as _env + +# Canonical names for execution backends used internally +_CANONICAL_MAP = { + "dlpack": "tvm_ffi", # historical alias +} + + +def _canon_backend(name: str | None) -> str | None: + if name is None: + return None + key = str(name).lower() + return _CANONICAL_MAP.get(key, key) + + +def _target_kind(target: Target) -> str: + # tvm.target.Target always has kind.name + return target.kind.name + + +def allowed_backends_for_target(target: Target, *, include_unavailable: bool = True) -> list[str]: + """Return allowed execution backends for a given TVM target kind. + + include_unavailable: if False, this will filter out backends that are known + to be unavailable at runtime (e.g., NVRTC without cuda-python installed). 
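+    Example (values mirror the branches below and are informational only): a CUDA target
+    yields ["tvm_ffi", "nvrtc", "cython", "ctypes"], Metal yields ["torch"], and the CPU "c"
+    target yields ["cython", "ctypes", "tvm_ffi"].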
+ """ + kind = _target_kind(target) + + if is_cutedsl_target(target): + return ["cutedsl"] + elif kind == "cuda": + allowed = ["tvm_ffi", "nvrtc", "cython", "ctypes"] + elif kind == "hip": + allowed = ["tvm_ffi", "cython", "ctypes"] + elif kind == "metal": + allowed = ["torch"] + elif kind == "c": # CPU C backend + allowed = ["cython", "ctypes", "tvm_ffi"] + else: + # Fallback: prefer portable hosts + allowed = ["cython", "ctypes", "tvm_ffi"] + + if not include_unavailable: + # Drop NVRTC if not importable + try: + from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy + + if not is_nvrtc_available and "nvrtc" in allowed: + allowed = [b for b in allowed if b != "nvrtc"] + except Exception: + # Be conservative and keep nvrtc if detection itself fails + pass + + return allowed + + +def _format_options(options: Iterable[str]) -> str: + return ", ".join(sorted(options)) + + +def resolve_execution_backend(requested: str | None, target: Target) -> str: + """Resolve an execution backend string to a concrete backend. + + - Supports the alias "dlpack" -> "tvm_ffi". + - Supports the sentinel "auto" which selects a sensible default per target. + - Validates the combination (target, backend) and raises with helpful + alternatives when invalid. + """ + req = _canon_backend(requested) + allowed_all = allowed_backends_for_target(target, include_unavailable=True) + allowed_avail = allowed_backends_for_target(target, include_unavailable=False) + + def _require_gemm_v1_for_cutedsl(): + if not _env.use_gemm_v1(): + raise ValueError( + "CuTeDSL backend requires GEMM v1. Please set environment variable TILELANG_USE_GEMM_V1=1 before importing tilelang." + ) + # Fail fast with a clear error if CuTeDSL dependencies are missing or incompatible. + try: + from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy + + check_cutedsl_available() + except ImportError as e: + # Keep resolve_execution_backend's error semantics (ValueError) while + # preserving the actionable ImportError message. + raise ValueError(str(e)) from e + + # Default selection for auto/None + if req in (None, "auto"): + if is_cutedsl_target(target): + _require_gemm_v1_for_cutedsl() + return "cutedsl" + kind = _target_kind(target) + if kind == "cuda": + choice = "tvm_ffi" + elif kind == "metal": + choice = "torch" + else: + choice = "cython" + # If the chosen default is not available (very rare), fall back to first available + if choice not in allowed_avail and allowed_avail: + choice = allowed_avail[0] + return choice + + # Validate against allowed + if req not in allowed_all: + raise ValueError( + f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. " + f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'." + ) + + # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed) + if req not in allowed_avail: + raise ValueError( + f"Execution backend '{requested}' requires extra dependencies and is not available now. " + f"Try one of: {_format_options(allowed_avail)}." 
+ ) + + # CuTeDSL requires GEMM v1 + if req == "cutedsl": + _require_gemm_v1_for_cutedsl() + + return req diff --git a/tilelang/original/tilelang/jit/kernel.py b/tilelang/original/tilelang/jit/kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..a788e76ba77473fd52bff0ea6cd9d957ebe3a07f --- /dev/null +++ b/tilelang/original/tilelang/jit/kernel.py @@ -0,0 +1,794 @@ +from __future__ import annotations +from typing import Any, Callable, Generic, Literal, TypeVar + +# Python 3.9 compatibility for ParamSpec +try: + from typing import ParamSpec +except ImportError: # Python < 3.10 + from typing_extensions import ParamSpec + +from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target +from tvm.target import Target +from tvm.tir import PrimFunc + +import tilelang +from tilelang import tvm +from tilelang import env +from tilelang.engine.param import CompiledArtifact, KernelParam +from tilelang.jit.adapter import ( + BaseKernelAdapter, + CtypesKernelAdapter, + CythonKernelAdapter, + CuTeDSLKernelAdapter, + TVMFFIKernelAdapter, + MetalKernelAdapter, +) +from tilelang.profiler import Profiler, TensorSupplyType +from tilelang.utils.target import determine_target +from tilelang.contrib import nvcc as tl_nvcc +from tilelang.transform import PassConfigKey +import logging +import os + +logger = logging.getLogger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class JITKernel(Generic[_P, _T]): + """ + A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. + + Attributes + ---------- + artifact : CompiledArtifact + The compiled artifact containing the runtime module and parameters. + adapter : BaseKernelAdapter + The adapter for the compiled function. + torch_function : Callable + The compiled function that can be invoked as a PyTorch-compatible function. + """ + + prim_func: PrimFunc = None + artifact: CompiledArtifact = None + adapter: BaseKernelAdapter = None + torch_function: Callable = None + + # tuner result + latency: float = None + config: dict[str, Any] = None + ref_latency: float = None + + def __init__( + self, + func: PrimFunc = None, + out_idx: list[int] | int = None, + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + target: str | Target = "auto", + target_host: str | Target = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + from_database: bool = False, + compile_flags: list[str] | None = None, + ): + """ + Initializes a TorchFunction instance. + + Parameters + ---------- + func : tvm.tir.PrimFunc, optional + The TileLang TIR function to compile and wrap. + out_idx : Union[List[int], int], optional + Index(es) of the output tensors to return (default: None). + execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional + Execution backend to use for kernel execution. + target : Union[str, Target], optional + Compilation target, either as a string or a TVM Target object (default: "auto"). + target_host : Union[str, Target], optional + Target host for cross-compilation (default: None). + verbose : bool, optional + Whether to enable verbose output (default: False). + pass_configs : dict, optional + Additional keyword arguments to pass to the Compiler PassContext. + Refer to `tilelang.PassConfigKey` for supported options. + from_database : bool, optional + Whether to create a TorchFunction from a database. 
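+        Examples
+        --------
+        A minimal sketch; the kernel function, tensors, and output index are illustrative:
+
+        >>> kernel = JITKernel(matmul_func, out_idx=[2], target="auto")
+        >>> C = kernel(A, B)  # dispatches through the selected execution backend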
+ """ + self.prim_func = func + self.execution_backend = execution_backend + self.target_host = target_host + self.verbose = verbose + + if pass_configs is None: + pass_configs = {} + self.pass_configs = pass_configs + + self.compile_flags = [compile_flags] if isinstance(compile_flags, str) else compile_flags + + # Ensure the target is always a valid TVM Target object. + self.target = determine_target(target, return_object=True) + + # Validate the execution backend. + assert execution_backend in [ + "tvm_ffi", + "ctypes", + "cython", + "nvrtc", + "torch", + "cutedsl", + ], f"Invalid execution backend. {execution_backend}" + if execution_backend == "cython": + from tilelang.contrib.cc import get_cplus_compiler + + assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends." + + if from_database: + return + + # Print log on compilation starts + # NOTE(Chenggang): printing could let the training/inference framework easier to know + # whether the communication timeout is from compilation + if env.is_print_on_compilation_enabled(): + # assert func must have "global_symbol" + func_name = func.attrs.get("global_symbol") + assert func_name is not None, "func must have global_symbol" + logger.info(f"TileLang begins to compile kernel `{func_name}` with `{out_idx=}`") + + # Compile the TileLang function and create a kernel adapter for execution. + adapter = self._compile_and_create_adapter(func, out_idx) + + if env.is_print_on_compilation_enabled(): + func_name = func.attrs.get("global_symbol") + assert func_name is not None, "func must have global_symbol" + logger.info(f"TileLang completes to compile kernel `{func_name}`") + + # The adapter's function is assigned as the callable function for this instance. + self.adapter = adapter + self.torch_function = adapter.func + + @classmethod + def from_database( + cls, + func: PrimFunc, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + params: list[KernelParam], + target: str | Target, + target_host: str | Target, + out_idx: list[int] | int, + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + """ + Alternative constructor to create a TorchFunction directly from a database. + """ + instance = cls( + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + target=target, + target_host=target_host, + pass_configs=pass_configs, + from_database=True, + compile_flags=compile_flags, + ) + + instance.adapter = instance._create_adapter_from_database( + func_or_mod=func, + params=params, + result_idx=out_idx, + target=target, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + instance.torch_function = instance.adapter.func + return instance + + def __call__(self, *args: _P.args, **kwds: _P.kwargs) -> _T: + """ + Invokes the compiled function with the given arguments. + + Parameters + ---------- + *args : Any + Positional arguments for the function. + **kwds : Any + Keyword arguments for the function. + + Returns + ------- + Any + The result of the function execution. + """ + return self.torch_function(*args, **kwds) + + def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter: + """ + Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. 
+ + Parameters + ---------- + tilelang_func : tvm.tir.PrimFunc + The TileLang (TVM TIR) function to compile. + + Returns + ------- + BaseKernelAdapter + The compiled and ready-to-run kernel adapter. + """ + verbose = self.verbose + target = self.target + target_host = self.target_host + + execution_backend = self.execution_backend + pass_configs = self.pass_configs or {} + + compile_flags = self.compile_flags + + if compile_flags is not None: + compile_flags_cfg = pass_configs.get(PassConfigKey.TL_DEVICE_COMPILE_FLAGS) + pass_configs[PassConfigKey.TL_DEVICE_COMPILE_FLAGS] = ( + compile_flags_cfg + compile_flags if compile_flags_cfg is not None else compile_flags + ) + + # Compile the function with TVM, optimizing with shared memory lowering. + enable_host_codegen = execution_backend == "tvm_ffi" + enable_device_compile = execution_backend == "tvm_ffi" + with tvm.transform.PassContext(opt_level=3, config=pass_configs), self.target: + artifact = tilelang.lower( + tilelang_func, + target=target, + target_host=target_host, + enable_host_codegen=enable_host_codegen, + enable_device_compile=enable_device_compile, + ) + + self.artifact = artifact + + # Create an adapter based on the specified execution backend. + if execution_backend == "tvm_ffi": + # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack. + # But we need to ensure that the runtime is enabled and the runtime module is not None. + assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module." + adapter = TVMFFIKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + rt_mod=artifact.rt_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "ctypes": + adapter = CtypesKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "cython": + adapter = CythonKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "nvrtc": + from tilelang.jit.adapter import NVRTCKernelAdapter + + adapter = NVRTCKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "torch": + assert is_metal_target(target) + adapter = MetalKernelAdapter( + params=artifact.params, + result_idx=out_idx, + # target=target, + func_or_mod=tilelang_func, + # host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + kernel_global_source=artifact.kernel_source, + verbose=verbose, + # pass_configs=pass_configs, + # compile_flags=compile_flags, + ) + elif execution_backend == "cutedsl": + assert is_cutedsl_target(target) + adapter = CuTeDSLKernelAdapter( + params=artifact.params, + result_idx=out_idx, + 
target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + else: + # Handle invalid backend. + raise ValueError(f"Invalid execution backend: {execution_backend}") + + return adapter + + def _create_adapter_from_database( + self, + params: list[KernelParam], + result_idx: list[int] | int, + target: str | Target, + func_or_mod: PrimFunc | tvm.runtime.Module, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ) -> BaseKernelAdapter: + target = self.target + execution_backend = self.execution_backend + + # Create an adapter based on the specified execution backend. + if execution_backend == "tvm_ffi": + adapter = TVMFFIKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "ctypes": + adapter = CtypesKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "cython": + adapter = CythonKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + ) + elif execution_backend == "nvrtc": + from tilelang.jit.adapter import NVRTCKernelAdapter + + adapter = NVRTCKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + elif execution_backend == "cutedsl": + adapter = CuTeDSLKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + else: + # Handle invalid backend. + raise ValueError(f"Invalid execution backend: {execution_backend}") + + return adapter + + @classmethod + def from_tilelang_function(cls, tilelang_func: PrimFunc, **kwargs): + """ + Alternative constructor to create a TorchFunction directly from a TileLang PrimFunc. + + Parameters + ---------- + tilelang_func : tvm.tir.PrimFunc + The TileLang (TVM TIR) function to compile. + **kwargs : dict + Additional keyword arguments to pass to the constructor. + + Returns + ------- + TorchFunction + An instance of TorchFunction wrapping the compiled function. + """ + return cls(func=tilelang_func, **kwargs) + + def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler: + """ + Creates a profiler to benchmark the compiled runtime module. 
+ + Parameters + ---------- + tensor_supply_type : TensorSupplyType, optional + The type of input tensors to supply for profiling (default: TensorSupplyType.Auto). + + Returns + ------- + Profiler + A Profiler instance for benchmarking the runtime module. + """ + return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter) + + def get_kernel_source(self, kernel_only: bool = True) -> str: + """ + Returns the source code of the compiled kernel function. + + Returns + ------- + str + The source code of the compiled kernel function. + """ + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}: + return self.adapter.get_kernel_source(kernel_only=kernel_only) + return self.artifact.kernel_source + + def get_host_source(self) -> str: + """ + Returns the source code of the host function. + """ + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}: + return self.adapter.get_host_source() + assert self.artifact.host_mod is not None, "host_mod is not available" + return str(self.artifact.host_mod) + + def run_once(self, func: Callable | None = None) -> None: + return self.get_profiler().run_once(func) + + def show_source(self, which: Literal["kernel", "host", "both"] = "kernel") -> None: + """ + Print generated source code to stdout. + + Parameters + ---------- + which : Literal["kernel", "host", "both"], optional + Select which source to print. Defaults to "kernel". + + Examples + -------- + >>> jit_kernel.show_source() # print kernel source + >>> jit_kernel.show_source("host") # print host source + >>> jit_kernel.show_source("both") # print both sources + """ + try: + if which == "kernel": + src = self.get_kernel_source() + print(src) + elif which == "host": + src = self.get_host_source() + # Host is generally C/C++ + print(src) + elif which == "both": + print("===== Kernel Source =====") + ksrc = self.get_kernel_source() + print(ksrc) + print("===== Host Source =====") + hsrc = self.get_host_source() + print(hsrc) + else: + raise ValueError(f"Unknown option for 'which': {which}") + except Exception as e: + logger.error(f"Failed to show source code: {e}") + + def export_sources(self, kernel_path: str | None = None, host_path: str | None = None) -> None: + """ + Export generated source code to files. + + Parameters + ---------- + kernel_path : Optional[str] + Destination file path to write the kernel source. If None, skips writing kernel code. + host_path : Optional[str] + Destination file path to write the host source. If None, skips writing host code. + + Examples + -------- + >>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu") + >>> jit_kernel.export_sources(host_path="/tmp/host.cc") + >>> jit_kernel.export_sources( + ... kernel_path="/tmp/kernel.cu", + ... host_path="/tmp/host.cc", + ... 
) + """ + if kernel_path is None and host_path is None: + raise ValueError("At least one of kernel_path or host_path must be provided.") + try: + if kernel_path is not None: + dir_path = os.path.dirname(kernel_path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + with open(kernel_path, "w") as f: + f.write(self.get_kernel_source()) + if host_path is not None: + dir_path = os.path.dirname(host_path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + with open(host_path, "w") as f: + f.write(self.get_host_source()) + except Exception as e: + logger.error(f"Failed to export sources: {e}") + + # Backward compatibility alias (deprecated) + def print_source_code(self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None: + """ + Deprecated: use show_source() or export_sources() instead. + + Parameters + ---------- + which : Literal["kernel", "host", "both"], optional + Kept for backward compatibility with printing behavior. + file : Optional[str] + If provided, behaves like export_sources(kernel_path=file). + + Examples + -------- + >>> # New API (preferred) + >>> jit_kernel.show_source("both") + >>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu") + + >>> # Old API (still works but deprecated) + >>> jit_kernel.print_source_code(file="/tmp/kernel.cu") + """ + logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.") + if file is not None: + # Historical behavior wrote only kernel source when file provided + self.export_sources(kernel_path=file) + else: + self.show_source(which=which) + + def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel: + """ + Updates the tuning results for this kernel. + + Parameters + ---------- + latency : float + The measured latency of this kernel configuration. + config : Dict[str, Any] + The configuration parameters used for this kernel. + ref_latency : float + The reference latency to compare against. + + Returns + ------- + None + """ + self.latency = latency + self.config = config + self.ref_latency = ref_latency + + return self + + def get_tuner_result(self) -> dict[str, Any]: + """ + Gets the tuning results for this kernel. + + Returns + ------- + Dict[str, Any] + A dictionary containing: + - latency: The measured latency of this kernel + - config: The configuration parameters used + - ref_latency: The reference latency for comparison + """ + if self.latency is None: + raise ValueError("Tuning results are not available. Please tune the kernel first.") + + return { + "latency": self.latency, + "config": self.config, + "ref_latency": self.ref_latency, + } + + @property + def out_idx(self) -> list[int]: + return self.adapter.result_idx + + @property + def params(self) -> list[KernelParam]: + return self.artifact.params if self.artifact else self.adapter.params + + @property + def kernel_source(self) -> str: + return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source + + @property + def host_source(self) -> str: + return str(self.artifact.host_mod) if self.artifact else "" + + def export_library(self, kernel_file: str) -> None: + """ + Exports the compiled kernel function to a shared library file. + + Parameters + ---------- + kernel_file : str + The path to the shared library file to create. 
+ """ + # rt_module: tvm.runtime.Module = None + # rt_params: dict = None + # adapter: BaseKernelAdapter = None + # torch_function: Callable = None + # rt_module: use export_library to export + # rt_params: use cloudpickle to serialize + + # Export the compiled kernel function to a shared library file. + self.rt_module.export_library(kernel_file) + + def _get_ptx(self, verbose: bool | None = None) -> str: + """ + Compile and return PTX for the current kernel (CUDA only). + + Parameters + ---------- + verbose : Optional[bool] + Whether to enable verbose NVRTC logs. Defaults to self.verbose. + + Returns + ------- + str + The compiled PTX text. + """ + if not is_cuda_target(self.target): + raise ValueError("PTX is only available for CUDA targets.") + # Prefer NVCC for PTX generation via contrib helper + code = self.get_kernel_source() + if verbose is None: + verbose = self.verbose + # Ensure target is set so nvcc picks correct arch via Target.current() + with self.target: + return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose) + + def show_ptx(self) -> None: + """ + Print compiled PTX for the kernel (CUDA only). + + Examples + -------- + >>> jit_kernel.show_ptx() + """ + try: + ptx = self._get_ptx() + print(ptx) + except Exception as e: + logger.error(f"Failed to show PTX: {e}") + + def export_ptx(self, path: str) -> None: + """ + Export compiled PTX to a file (CUDA only). + + Parameters + ---------- + path : str + Destination file path to write PTX. + + Examples + -------- + >>> jit_kernel.export_ptx("/tmp/kernel.ptx") + """ + if not path: + raise ValueError("path must be provided to export PTX") + try: + ptx = self._get_ptx() + dir_path = os.path.dirname(path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + with open(path, "w") as f: + f.write(ptx) + logger.info(f"PTX saved to {os.path.abspath(path)}") + except Exception as e: + logger.error(f"Failed to export PTX: {e}") + + def _get_sass(self, verbose: bool | None = None) -> str: + """ + Compile and return SASS for the current kernel (CUDA only). + + Parameters + ---------- + verbose : Optional[bool] + Whether to enable verbose tool logs. Defaults to self.verbose. + + Returns + ------- + str + The disassembled SASS text. + """ + if not is_cuda_target(self.target): + raise ValueError("SASS is only available for CUDA targets.") + code = self.get_kernel_source() + if verbose is None: + verbose = self.verbose + with self.target: + return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose) + + def show_sass(self) -> None: + """ + Print disassembled SASS for the kernel (CUDA only). + + Examples + -------- + >>> jit_kernel.show_sass() + """ + try: + sass = self._get_sass() + print(sass) + except Exception as e: + logger.error(f"Failed to show SASS: {e}") + + def export_sass(self, path: str) -> None: + """ + Export disassembled SASS to a file (CUDA only). + + Parameters + ---------- + path : str + Destination file path to write SASS. 
+ + Examples + -------- + >>> jit_kernel.export_sass("/tmp/kernel.sass") + """ + if not path: + raise ValueError("path must be provided to export SASS") + try: + sass = self._get_sass() + dir_path = os.path.dirname(path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + with open(path, "w") as f: + f.write(sass) + logger.info(f"SASS saved to {os.path.abspath(path)}") + except Exception as e: + logger.error(f"Failed to export SASS: {e}") diff --git a/tilelang/original/tilelang/jit/param.py b/tilelang/original/tilelang/jit/param.py new file mode 100644 index 0000000000000000000000000000000000000000..175a42f362ecf4de976ff9e7ac17c44cf82339a3 --- /dev/null +++ b/tilelang/original/tilelang/jit/param.py @@ -0,0 +1,42 @@ +from typing import ( + Any, + TypeVar, +) +from typing_extensions import ParamSpec + + +# --- Mocking dependencies for the example to run --- +# In your actual code, these would be your real types. +class Program: + """Placeholder for the type returned by the original decorated function.""" + + def __init__(self, data: str): + self.data = data + + def __repr__(self): + return f"Program('{self.data}')" + + +class Kernel: + """Placeholder for the type of the compiled kernel.""" + + def __init__(self, source: str, out_idx: Any): + self.source_code = source + self.out_idx = out_idx + + def get_kernel_source(self) -> str: + return self.source_code + + def __repr__(self): + return f"Kernel('{self.source_code[:20]}...')" + + +# --- End Mocking --- + +# P (Parameters) captures the argument types of the decorated function. +_P = ParamSpec("_P") +# R_prog (Return type of Program) captures the return type of the original decorated function. +# We assume the original function returns something compatible with 'Program'. +_RProg = TypeVar("_RProg", bound=Program) + +__all__ = ["Program", "Kernel", "_P", "_RProg"] diff --git a/tilelang/original/tilelang/language/__init__.py b/tilelang/original/tilelang/language/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97f8385814e9c35e41227f6822c9f4cdb28f0dce --- /dev/null +++ b/tilelang/original/tilelang/language/__init__.py @@ -0,0 +1,122 @@ +"""The language interface for tl programs.""" + +from __future__ import annotations + +# from .parser import * +# now is fully compatible with the upstream +# tir script +# TODO(lei): remove this import once the +# upstream tir script is fully compatible +from tvm.script.parser.tir import * +from . 
import overrides as _overrides # noqa: F401 + +# from .tir import prim_func, macro, # noqa: F401 +from .v2 import * # noqa: F401 +from .tir.ir import * # noqa: F401 +from tilelang.layout import Layout, Fragment # noqa: F401 +from .proxy import ptr, make_tensor # noqa: F401 +from .v2.annot import ( + Buffer, # noqa: F401 + Tensor, # noqa: F401 + StridedTensor, # noqa: F401 + FragmentBuffer, # noqa: F401 + SharedBuffer, # noqa: F401 + LocalBuffer, # noqa: F401 + dyn, # noqa: F401 +) +from .loop import ( + Parallel, # noqa: F401 + Persistent, # noqa: F401 + Pipelined, # noqa: F401 + serial, # noqa: F401 + unroll, # noqa: F401 + Serial, # noqa: F401 + Unroll, # noqa: F401 +) +from .frame import has_let_value, get_let_value # noqa: F401 +from .math_intrinsics import * # noqa: F401 +from .kernel import ( + Kernel, # noqa: F401 + KernelLaunchFrame, # noqa: F401 + get_thread_binding, # noqa: F401 + get_thread_bindings, # noqa: F401 + get_block_binding, # noqa: F401 + get_block_bindings, # noqa: F401 +) +from .warpgroup import ws # noqa: F401 +from .allocate import ( + alloc_var, # noqa: F401 + alloc_local, # noqa: F401 + alloc_shared, # noqa: F401 + alloc_fragment, # noqa: F401 + alloc_barrier, # noqa: F401 + alloc_tmem, # noqa: F401 + alloc_reducer, # noqa: F401 + alloc_descriptor, # noqa: F401 + alloc_wgmma_desc, # noqa: F401 + alloc_tcgen05_smem_desc, # noqa: F401 + alloc_tcgen05_instr_desc, # noqa: F401 + empty, # noqa: F401 +) +from .copy_op import copy, c2d_im2col # noqa: F401 +from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 +from .gemm_op import gemm, gemm_v1, gemm_v2 # noqa: F401 +from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 +from .fill_op import fill, clear # noqa: F401 +from .reduce_op import ( + reduce, # noqa: F401 + reduce_max, # noqa: F401 + reduce_min, # noqa: F401 + reduce_sum, # noqa: F401 + reduce_abssum, # noqa: F401 + reduce_absmax, # noqa: F401 + reduce_bitand, # noqa: F401 + reduce_bitor, # noqa: F401 + reduce_bitxor, # noqa: F401 + cumsum, # noqa: F401 + finalize_reducer, # noqa: F401 + warp_reduce_sum, # noqa: F401 + warp_reduce_max, # noqa: F401 + warp_reduce_min, # noqa: F401 + warp_reduce_bitand, # noqa: F401 + warp_reduce_bitor, # noqa: F401 +) +from .print_op import print, device_assert # noqa: F401 +from .customize import ( + atomic_max, # noqa: F401 + atomic_min, # noqa: F401 + atomic_add, # noqa: F401 + atomic_addx2, # noqa: F401 + atomic_addx4, # noqa: F401 + dp4a, # noqa: F401 + clamp, # noqa: F401 + reshape, # noqa: F401 + view, # noqa: F401 + atomic_load, # noqa: F401 + atomic_store, # noqa: F401 + loop_break, # noqa: F401 +) +from .logical import any_of, all_of # noqa: F401 +from .builtin import * # noqa: F401 +from .builtin import __ldg as __ldg # noqa: F401 + +from .utils import index_to_coordinates # noqa: F401 + +from .symbolics import dynamic, symbolic # noqa: F401 +from .annotations import ( # noqa: F401 + use_swizzle, + annotate_layout, + annotate_safe_value, + annotate_l2_hit_ratio, + annotate_restrict_buffers, +) + +from .random import ( + rng_init, # noqa: F401 + rng_rand, # noqa: F401 +) + + +def import_source(source: str | None = None): + # source is the source code to be imported + return block_attr({"pragma_import_c": source}) if source is not None else None diff --git a/tilelang/original/tilelang/language/allocate.py b/tilelang/original/tilelang/language/allocate.py new file mode 100644 index 0000000000000000000000000000000000000000..e9338fa6efc8c49984fd042c839bfd3d7b66976f --- /dev/null +++ 
b/tilelang/original/tilelang/language/allocate.py @@ -0,0 +1,281 @@ +"""Memory allocation utilities for Tile-AI programs. + +This module provides a set of functions for allocating different types of memory buffers +in Tile-AI programs. It wraps TVM's buffer allocation functionality with convenient +interfaces for different memory scopes. + +Available allocation functions: + - alloc_shared: Allocates shared memory buffers for inter-thread communication + - alloc_local: Allocates local memory buffers for thread-private storage + - alloc_fragment: Allocates fragment memory buffers for specialized operations + - alloc_var: Allocates single-element variable buffers + +Each function takes shape and dtype parameters and returns a TVM buffer object +with the appropriate memory scope. +""" + +from __future__ import annotations +from typing import TypeVar, overload, Literal, Callable + +# Python 3.9 compatibility for advanced typing features (PEP 646) +try: + from typing import TypeVarTuple, Unpack # type: ignore[attr-defined] +except Exception: + from typing_extensions import TypeVarTuple, Unpack # type: ignore +from tilelang import tvm as tvm +from tvm.script import tir as T +from tvm.tir import PrimExpr +from tvm.script.parser.tir import block_attr +from tvm.tir.buffer import Buffer +from tvm.tir.expr import FloatImm, IntImm +from .v2 import dtypes as _dtypes +from .v2.dtypes import dtype as tl_dtype +from .v2.builder import OutTensor +from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer + +_Shapes = TypeVarTuple("_Shapes") +_DType = TypeVar("_DType") + + +def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]: + """Allocate a shared memory buffer for inter-thread communication. + + Args: + shape (tuple): The shape of the buffer to allocate + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + scope (str, optional): The memory scope. Defaults to "shared.dyn" + + Returns: + T.Buffer: A TVM buffer object allocated in shared memory + """ + if dtype == "bool": + # lei: This is a hack to handle bool type. + # Because tilelang's merge smem pass cannot merge bool type currently. + scope = "shared" + return T.alloc_buffer(shape, dtype, scope=scope) + + +def alloc_local(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]: + """Allocate a local memory buffer for thread-private storage. + + Args: + shape (tuple): The shape of the buffer to allocate + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + scope (str, optional): The memory scope. Defaults to "local" + + Returns: + T.Buffer: A TVM buffer object allocated in local memory + """ + return T.alloc_buffer(shape, dtype, scope=scope) + + +def alloc_fragment( + shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment" +) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: + """Allocate a fragment memory buffer for specialized operations. + + Args: + shape (tuple): The shape of the buffer to allocate + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + scope (str, optional): The memory scope. Defaults to "local.fragment" + + Returns: + T.Buffer: A TVM buffer object allocated in fragment memory + """ + return T.alloc_buffer(shape, dtype, scope=scope) + + +@overload +def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ... 
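+# Hedged usage sketch for the helpers above, inside a kernel body (shapes and dtypes are
+# illustrative):
+#   A_shared = T.alloc_shared((128, 64), "float16")   # defaults to the "shared.dyn" scope
+#   C_frag = T.alloc_fragment((128, 64), "float32")   # "local.fragment" scope
+#   tmp = T.alloc_local((4,), "int32")                # thread-private "local" scope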
+ + +@overload +def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ... + + +def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): + """Allocate a single-element variable buffer. + + Args: + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + *args: Optional positional arguments. A single positional string is treated + as the scope for backward compatibility. A single non-string positional + argument (or keyword ``init``) specifies the initializer. When two + positional arguments are provided, they are interpreted as + ``(init, scope)``. + scope (str, optional): The memory scope. Defaults to "local.var". + Use as keyword argument for clarity when also providing an initializer. + init (PrimExpr, optional): The optional initializer value. When provided, + the generated code will initialize the variable with this value instead + of defaulting to zero. + Examples: + a = T.alloc_var('int32', 1) # var with init 1 + a = T.alloc_var('int32', 'local.var') # var with local.var scope + a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope + a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope + a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope + Returns: + T.Buffer: A TVM buffer object allocated as a single-element variable + """ + parsed_scope = scope + parsed_init = init + + if len(args) == 1: + arg = args[0] + if isinstance(arg, str) and parsed_init is None and scope == "local.var": + parsed_scope = arg + else: + if parsed_init is not None: + raise TypeError("Initializer specified multiple times in alloc_var.") + parsed_init = arg + elif len(args) == 2: + if parsed_init is not None: + raise TypeError("Initializer specified multiple times in alloc_var.") + parsed_init, parsed_scope_arg = args + if not isinstance(parsed_scope_arg, str): + raise TypeError("Scope must be provided as a string in alloc_var.") + parsed_scope = parsed_scope_arg + elif len(args) > 2: + raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.") + + if not isinstance(parsed_scope, str): + raise TypeError("Scope must be a string in alloc_var.") + + buffer = T.alloc_buffer([1], dtype, scope=parsed_scope) + if parsed_init is not None: + if isinstance(parsed_init, (int, float, IntImm, FloatImm)): + block_attr({"tl.local_var_init": {buffer.data: tl_dtype(dtype)(parsed_init)}}) + else: + T.buffer_store(buffer, parsed_init, 0) + return buffer + + +def alloc_barrier(arrive_count: int): + """Allocate a barrier buffer. + + Args: + arrive_count (int): The number of threads that need to arrive at the barrier + + Returns: + T.Buffer: A TVM buffer object allocated as a barrier + """ + return T.alloc_buffer([arrive_count], _dtypes.uint64, scope="shared.barrier") + + +def alloc_tmem(shape, dtype): + """ + Allocate a Tensor Memory (TMEM) buffer for use with 5th generation Tensor Core operations (e.g., TCGEN5.MMA). + + TMEM is a dedicated on-chip memory introduced in Hopper GPUs, designed to reduce register pressure and enable asynchronous, single-threaded MMA operations. It is organized as a 2D array of 512 columns by 128 rows (lanes), with each cell being 32 bits. Allocation is performed in units of columns, and every lane of a column is allocated together. + + Key properties and requirements: + - The number of columns allocated must be a power of 2 and at least 32. 
+ - TMEM allocations are dynamic and must be explicitly deallocated. + - Both allocation and deallocation must be performed by the same warp. + - The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors. + - Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved. + - The number of columns allocated should not increase between any two allocations in the execution order within the CTA. + + Args: + num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512. + + Returns: + T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations. + + Note: + - TMEM is only available on supported architectures (e.g., Hopper and later). + - The buffer returned should be used according to TMEM access restrictions and deallocated appropriately. + """ + + assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation" + return T.alloc_buffer(shape, dtype, scope="shared.tmem") + + +def alloc_reducer(shape, dtype, op="sum", replication=None): + """ + Allocate a reducer buffer. + + Modifications needs to conform with `op`, + such as `op="sum"` requires `reducer[...] += ...` and + `op="max"` requires `reducer[...] = T.max(reducer[...], ...)`. + + Only after T.fill with proper initializer the reduction may begin; + only after T.finalize_reducer the partial results will be available. + + For `op="sum"`, filled value must be 0; for min and max, the filled initializer will become max or min clamper correspondingly. + You may want to use `T.max_value` for min and `T.min_value` for max. + + Args: + shape (tuple): The shape of the buffer to allocate + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + op (str): The reduce operation corresponded with the reducer + replication (str | None): Replication strategy, can be "all" or "none". Defaults to not specified, and the compiler will do whatever it want. + + Returns: + T.Buffer: A TVM buffer object allocated in thread-private storage, available to reduce values in T.Parallel loops. + """ + + assert op in ["sum", "max", "min"] + # TODO: support automatic layout + if replication is None: + replication = "none" + assert replication in ["all", "none"] + + reducer = T.alloc_buffer(shape, dtype, scope="local.fragment") + block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}}) + + return reducer + + +DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"] + + +def alloc_descriptor( + kind: DescKind = "wgmma", + dtype: str = _dtypes.uint64, +): + """Allocate a descriptor buffer for WGMMA and TCGEN5.MMA. + + Args: + kind: The descriptor kind, one of "wgmma", "tcgen05" ("utcmma" as alias). + + Returns: + T.Buffer: A TVM buffer object allocated as a descriptor + """ + + scope = "local.descriptor." + kind + # Buffer naming via `name` is not supported by this TVM builder signature; + # keep parameter for forward-compat, but do not pass it. 
+ return T.alloc_buffer([1], dtype, scope=scope) + + +def alloc_wgmma_desc(dtype: str = _dtypes.uint64): + return alloc_descriptor("wgmma", dtype=dtype) + + +def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64): + return alloc_descriptor("tcgen05_smem", dtype=dtype) + + +def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32): + return alloc_descriptor("tcgen05_instr", dtype=dtype) + + +# Alias: short name consistent with imports +def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32): + return alloc_tcgen05_instruction_desc(dtype) + + +@overload +def empty(shape: tuple[Unpack[_Shapes]], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... + + +def empty(*shape: Unpack[_Shapes], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + return OutTensor(shape[0], dtype) + elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): + return OutTensor(shape[0], shape[1]) + elif all([isinstance(x, (int, PrimExpr)) for x in shape]): + return OutTensor(shape, dtype) + else: + raise RuntimeError(f"Invalid shape {shape}") diff --git a/tilelang/original/tilelang/language/annotations.py b/tilelang/original/tilelang/language/annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..43ca9c05119fe6667b1b4d9fd363320e53110150 --- /dev/null +++ b/tilelang/original/tilelang/language/annotations.py @@ -0,0 +1,82 @@ +"""Annotation helpers exposed on the TileLang language surface.""" + +from typing import Callable + +from tilelang.layout import Layout +from tvm.script.parser.tir import attr, block_attr +from tvm.tir import FloatImm + +__all__ = [ + "use_swizzle", + "annotate_layout", + "annotate_safe_value", + "annotate_l2_hit_ratio", + "annotate_restrict_buffers", +] + + +def use_swizzle(panel_size: int, order: str = "row", enable: bool = True): + """Annotate a kernel to use a specific threadblock swizzle pattern.""" + device_func = "rasterization2DRow" if order == "row" else "rasterization2DColumn" + if not enable: + return None + return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>") + + +def annotate_layout(layout_map: dict): + """Annotate the layout of the buffer.""" + _layout_map = {} + for buffer, layout in layout_map.items(): + if isinstance(layout, Layout): + _layout_map[buffer.data] = layout + elif isinstance(layout, Callable): + _layout_map[buffer.data] = Layout(buffer.shape, layout) + else: + raise ValueError(f"Invalid layout: {layout}") + + return block_attr({"layout_map": _layout_map}) + + +def annotate_safe_value(safe_value_map: dict): + """Annotate the safe value of the buffer.""" + _safe_value_map = {} + for buffer, safe_value in safe_value_map.items(): + _safe_value_map[buffer.data] = safe_value + return block_attr({"safe_value_map": _safe_value_map}) + + +def annotate_l2_hit_ratio(l2_hit_ratio_map: dict): + """Annotate the L2 hit ratio of the buffer.""" + _l2_hit_ratio_map = {} + for buffer, hit_ratio in l2_hit_ratio_map.items(): + assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers" + _l2_hit_ratio_map[buffer.data] = FloatImm("float32", float(hit_ratio)) + return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map}) + + +def annotate_restrict_buffers(*buffers): + """Mark the given buffer parameters as non-restrict. + + This annotation tells codegen to omit the `__restrict__` qualifier for the + specified kernel buffer parameters. 
Use this when two (or more) buffers may + alias, for example overlapping slices from the same base tensor. + + Example + ------- + >>> @T.prim_func + ... def buggy_kernel(x: T.Tensor((N,), T.float32), + ... y: T.Tensor((N,), T.float32)): + ... T.annotate_restrict_buffers(x, y) + ... with T.Kernel(N, threads=32) as pid: + ... y[pid] = x[pid] + 1 + """ + if not buffers: + return None + data_vars = [] + for buf in buffers: + try: + data_vars.append(buf.data) + except Exception as e: + raise TypeError(f"annotate_restrict_buffers expects Buffer arguments, got {type(buf)}") from e + # Also return as block attribute (root block exists by default) for readability/tools. + return block_attr({"tl.non_restrict_params": data_vars}) diff --git a/tilelang/original/tilelang/language/ast/__init__.py b/tilelang/original/tilelang/language/ast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab6249b14484ca534cdc3ee3cc5bdb6ce6c2767 --- /dev/null +++ b/tilelang/original/tilelang/language/ast/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). +"""Package tvm.script.ir_builder.tir""" + +from .ir import * # noqa: F401 +from .ir import boolean as bool # noqa: F401 +from .ir import buffer as Buffer # noqa: F401 + +from tvm.script.ir_builder.tir import frame # noqa: F401 diff --git a/tilelang/original/tilelang/language/ast/_ffi_api.py b/tilelang/original/tilelang/language/ast/_ffi_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc74762a7089946a4c055d4623e506dce7985e7 --- /dev/null +++ b/tilelang/original/tilelang/language/ast/_ffi_api.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). 
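For reference, the annotation helpers defined in `annotations.py` above are called from inside a kernel body, typically right after the buffers they describe are allocated. The sketch below is illustrative only: the swizzle panel size, the layout callable, and the L2 hit ratio are arbitrary values chosen for the example, and the kernel skeleton (`T.Kernel`, `T.copy`, `T.alloc_shared`, `T.ceildiv`) is assumed from the rest of the package.

```python
import tilelang.language as T


def annotated_copy(M, N, block_M=64, block_N=64, dtype="float16"):
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Rasterize thread blocks in row-major panels of 8 (threadblock swizzle attr).
            T.use_swizzle(panel_size=8, order="row")
            # A callable is wrapped into Layout(buffer.shape, fn) by annotate_layout;
            # this toy mapping just shifts each row to illustrate the interface.
            T.annotate_layout({A_shared: lambda i, j: (i, (j + i) % block_N)})
            # Hint that roughly 80% of accesses to the global buffer A should persist in L2.
            T.annotate_l2_hit_ratio({A: 0.8})
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main
```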
+"""FFI APIs""" + +import tvm.ffi + +tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/tilelang/original/tilelang/language/ast/ir.py b/tilelang/original/tilelang/language/ast/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4faad43a3186efe32a2ff657ba21cc067dfa45 --- /dev/null +++ b/tilelang/original/tilelang/language/ast/ir.py @@ -0,0 +1,2225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). +# ruff: noqa +"""IRBuilder for TIR""" + +import functools +import inspect +from numbers import Integral +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + +import numpy as np # type: ignore + +from tvm import tir +from tvm import ir +from tvm.ir import Type +from tvm.ir.base import deprecated +from tvm.runtime import String, convert, ndarray +from tvm.target import Target + +# pylint: disable=unused-import +from tvm.target.codegen import llvm_lookup_intrinsic_id +from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr +from tvm.tir import op as _tir_op +from tvm.tir import type_annotation + +# import tir.expr for direct ir construction to pass structural_equal comparison +from tvm.tir.expr import ( + EQ, + GE, + GT, + LE, + LT, + NE, + Add, + And, + Broadcast, + BufferLoad, + Call, + CallEffectKind, + Cast, + CommReducer, + Div, + FloatImm, + FloorDiv, + FloorMod, + IntImm, + IterVar, + Max, + Min, + Mod, + Mul, + Not, + Or, + ProducerLoad, + Ramp, + Reduce, + Select, + Shuffle, + SizeVar, + StringImm, + Sub, + Var, +) +from tvm.tir.generic import cast + +from . import _ffi_api +from tvm.script.ir_builder.tir import frame + +# pylint: enable=unused-import + + +def buffer( + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + dtype: str = T.float32, + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "global", + align: int = 0, + offset_factor: int = 0, + buffer_type: str = "", + axis_separators: List[int] = None, +) -> Buffer: + """The buffer declaration function. + + Parameters + ---------- + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. 
+ + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : Buffer + The declared buffer. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] + else: + strides = [] + return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +@deprecated("T.buffer_decl(...)", "T.Buffer(...)") +def buffer_decl(*args, **kwargs): + return buffer(*args, **kwargs) + + +def prim_func(is_private: bool = False) -> frame.PrimFuncFrame: + """The primitive function statement. + + Parameters + ---------- + is_private : bool + Whether the PrimFunc is annotated as private + (if yes, it does not have a global symbol assigned; + otherwise, the global symbol is the PrimFunc's name) + + Returns + ------- + res : frame.PrimFuncFrame + The PrimFuncFrame. + """ + return _ffi_api.PrimFunc(is_private) # type: ignore[attr-defined] # pylint: disable=no-member + + +def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]: + """The PrimFunc arguments adding function. + + Parameters + ---------- + name : str + The name of the argument. + + var : Union[Var, Buffer] + The argument of Var or Buffer. + + Returns + ------- + res : Union[Var, Buffer] + The argument. + """ + return _ffi_api.Arg(name, obj) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_name(name: str) -> None: + """The PrimFunc naming statement. + + Parameters + ---------- + name : str + The name of the PrimFunc. + """ + _ffi_api.FuncName(name) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_attr(attrs: Dict[str, Any]) -> None: + """The PrimFunc annotation statement. + + Parameters + ---------- + attrs : Dict[str, Any] + The annotations of the PrimFunc. + """ + _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret(ret_type: Type) -> Type: + """The PrimFunc return type statement. + + Parameters + ---------- + ret_type : Type + The return type of the PrimFunc. + + Returns + ------- + res : Type + The return type. + """ + return _ffi_api.FuncRet(ret_type) # type: ignore[attr-defined] # pylint: disable=no-member + + +def match_buffer( + param: Union[Var, BufferLoad, BufferRegion], + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, + dtype: str = T.float32, + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: List[int] = None, +) -> Buffer: + """The buffer match function. + + Note + ---- + This function will perform different behavior, depending on the type of param. + If the param is a var in function parameter, it will create a buffer from DLTensor. + Else if the param is a subregion of other buffers, then create a subregion match inside a block. + + Example + ------- + Match buffer from function parameter + .. code-block:: python + A = T.match_buffer(a, (128, 128), dtype=T.float32) + + Match buffer from Buffer subregion + .. 
code-block:: python + A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype=T.float32) + + Parameters + ---------- + param : Union[Var, BufferLoad, BufferRegion] + The parameter of the PrimFunc to match. + + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : Buffer + The matched buffer. + """ + if shape is None: + if isinstance(param, BufferRegion): + dtype = param.buffer.dtype + shape = [region.extent for region in param.region] + else: + raise ValueError("Shape must be specified when binding input param") + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else T.int32 + strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] + else: + strides = [] + return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint: disable=no-member + param, + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: + """The block declaration statement. + + Parameters + ---------- + name : str + The name of the block. + + no_realize : bool + The flag whether to construct BlockRealize or Block. + + Returns + ------- + res : frame.BlockFrame + The BlockFrame. + """ + return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member + + +def init() -> frame.BlockInitFrame: + """The block initialization statement. + + Returns + ------- + res : frame.BlockInitFrame + The BlockInitFrame. + """ + return _ffi_api.Init() # type: ignore[attr-defined] # pylint: disable=no-member + + +def where(predicate: Union[PrimExpr, int]) -> None: + """The block predicate statement. + + Parameters + ---------- + predicate : Union[PrimExpr, Literal[0, 1]] + The predicate condition. + """ + if isinstance(predicate, bool): + predicate = IntImm("bool", predicate) + if isinstance(predicate, int): + if predicate in [0, 1]: + predicate = IntImm("bool", predicate) + else: + raise ValueError(f"Invalid value for predicate: {predicate}") + _ffi_api.Where(predicate) # type: ignore[attr-defined] # pylint: disable=no-member + + +def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + """The block buffer region reading statement. + + Parameters + ---------- + buffer_slices : List[Union[BufferRegion, BufferLoad]] + The array of buffer regions to read. 
+ """ + if len(buffer_slices) == 1: + if isinstance(buffer_slices[0], tuple): + buffer_slices = list(buffer_slices[0]) + elif isinstance(buffer_slices[0], list): + buffer_slices = buffer_slices[0] # type: ignore[assignment] + else: + buffer_slices = [buffer_slices[0]] + else: + buffer_slices = list(buffer_slices) # type: ignore[assignment] + _ffi_api.Reads(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member + + +def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + """The block buffer region writing statement. + + Parameters + ---------- + buffer_slices : List[Union[BufferRegion, BufferLoad]] + The array of buffer regions to write. + """ + if len(buffer_slices) == 1: + if isinstance(buffer_slices[0], tuple): + buffer_slices = list(buffer_slices[0]) + elif isinstance(buffer_slices[0], list): + buffer_slices = buffer_slices[0] # type: ignore[assignment] + else: + buffer_slices = [buffer_slices[0]] + else: + buffer_slices = list(buffer_slices) # type: ignore[assignment] + _ffi_api.Writes(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member + + +def block_attr(attrs: Dict[str, Any]) -> None: + """The block annotation statement. + + Parameters + ---------- + attrs : Dict[str, Any] + The annotation of the block. + """ + return _ffi_api.BlockAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def alloc_buffer( + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + dtype: str = T.float32, + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: List[int] = None, +) -> Buffer: + """The buffer allocation function. + + Parameters + ---------- + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : Buffer + The allocated buffer. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] + else: + strides = [] + return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: + """The range constructor. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr]] + The domain. + + Returns + ------- + res : Range + The Range. 
+ """ + if isinstance(dom, ir.Range): + return dom + if isinstance(dom, (list, tuple)): + return ir.Range(dom[0], dom[1]) + if hasattr(dom, "dtype"): + return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(0, dom) + + +class axis: # pylint: disable=invalid-name + """The axis class""" + + @staticmethod + def spatial( + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = T.int32, + ) -> Var: + """The spatial block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member + _as_range(dom), binding, dtype + ) + + @staticmethod + def reduce( + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = T.int32, + ) -> Var: + """The reduced block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member + _as_range(dom), binding, dtype + ) + + @staticmethod + def scan( + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = T.int32, + ) -> Var: + """The scanning block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member + _as_range(dom), binding, dtype + ) + + @staticmethod + def opaque( + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = T.int32, + ) -> Var: + """The opaque block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member + _as_range(dom), binding, dtype + ) + + @staticmethod + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = T.int32) -> Union[List[Var], Var]: + """The block axis remapping function. + + Parameters + ---------- + kinds : str + The types of the iteration variables. + + bindings : List[PrimExpr] + The binding values of the iteration variables. + + dtype : str + The data types of the iteration variables. + + Returns + ------- + res : Var + The iteration variables. 
+ """ + iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member + kinds, bindings, dtype + ) + return iter_vars[0] if len(iter_vars) == 1 else iter_vars + + S = spatial # pylint: disable=invalid-name + R = reduce # pylint: disable=invalid-name + + +def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: + """The serial For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + +def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: + """The parallel For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + +def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: + """The vectorized For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + +def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: + """The unrolled For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + +def thread_binding( + start: PrimExpr, + stop: PrimExpr = None, + thread: str = None, + *, + annotations: Dict[str, Any] = None, +) -> frame.ForFrame: + """The thread-binding For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + thread : str + The thread for loop variable to bind. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. 
+ """ + if thread is None: + if not isinstance(stop, str): + raise ValueError("Thread cannot be None for thread_binding") + thread = stop + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + elif stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member + start, stop, thread, annotations + ) + + +def grid(*extents: PrimExpr) -> frame.ForFrame: + """The grid For statement. + + Parameters + ---------- + extents : PrimExpr + The extents of the iteration. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name + """Create an assertion statement. + + Parameters + ---------- + condition : PrimExpr + The PrimExpr to test. + + message : str + The output error message when the assertion fails. + + Returns + ------- + res : frame.AssertFrame + The result AssertFrame. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member + + +def LetStmt( # pylint: disable=invalid-name + value: PrimExpr, + type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name + *, + var: Optional[Var] = None, # pylint: disable=redefined-outer-name +) -> frame.LetFrame: + """Create a LetStmt binding + + Parameters + ---------- + value : PrimExpr + The value to be bound. + type_annotation : Optional[Type] = None + The type annotation of the let binding. Usually it is used for fine-grained var typing, + particularly, PointerType. + var : Optional[Var] = None + The variable to bind. If not specified, a new variable will be created. + + Returns + ------- + let_frame : frame.LetFrame + The result LetFrame. + """ + if type_annotation is not None: + if callable(type_annotation): + type_annotation = type_annotation() + if isinstance(type_annotation, Var): + type_annotation = type_annotation.type_annotation + return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Let( # pylint: disable=invalid-name + expr: PrimExpr, + where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name +) -> PrimExpr: + """Create a Let expression binding""" + assert len(where) == 1, "T.Let only allows `where` to have exactly one element" + var, value = list(where.items())[0] # pylint: disable=redefined-outer-name + return tir.Let(var, value, expr) + + +def let( + v: Var, + value: PrimExpr, + body: PrimExpr = None, +) -> frame.LetFrame: + """Create a new let binding. + + Parameters + ---------- + v : Var + The variable to bind. + + value : PrimExpr + The value to be bound. + + body : PrimExpr + The body expression, None will be used if it was not specified. + + Returns + ------- + res : frame.LetFrame + The result LetFrame. 
+ """ + + @deprecated("T.let", "T.Let") + def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: + return tir.Let(v, value, body) + + @deprecated("T.let", "T.LetStmt") + def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: + return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + + if body is None: + return let_stmt(v, value) + else: + return let_expr(v, value, body) + + +def realize( + buffer_slice: BufferRegion, + storage_scope: str, + condition: PrimExpr = True, +) -> frame.RealizeFrame: + """Create a realization. + + Parameters + ---------- + buffer_slice : BufferRegion + The region of buffer access. + + storage_scope : str + The storage scope associated with this realization. + + condition: PrimExpr + The condition expression, the default is True. + + Returns + ------- + res : frame.RealizeFrame + The result RealizeFrame. + """ + return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member + buffer_slice, storage_scope, condition + ) + + +def allocate( + extents: List[PrimExpr], + dtype: str, + scope: str = "global", + condition: PrimExpr = None, + annotations=None, +) -> frame.AllocateFrame: + """Allocate node. + + Parameters + ---------- + extents : List[PrimExpr] + The extents of the allocate. + + dtype : str + The data type of the buffer. + + scope : str + The storage scope. + + condition : PrimExpr + The condition. + + annotations: Optional[Mapping[str, Object]] + Additional annotation hints. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member + extents, dtype, scope, condition, annotations + ) + + +def allocate_const( + data: List[PrimExpr], + dtype: str, + extents: List[PrimExpr], + annotations=None, +) -> frame.AllocateConstFrame: + """Allocate constant node. + + Parameters + ---------- + data : List[PrimExpr] + The data associated with the constant. + + dtype : str + The data type of the buffer. + + extents : List[PrimExpr] + The extents of the allocate. + + annotations : Optional[Map] + Additional annotations about the allocation. + """ + np_data = np.asarray(data, dtype=dtype) + prod_extent = 1 + for extent in extents: + prod_extent *= extent + prod_shape = 1 + for shape in np_data.shape: + prod_shape *= shape + if prod_extent == prod_shape: + np_data = np_data.reshape(extents) + + return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member + ndarray.array(np_data), dtype, extents, annotations + ) + + +def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: + """Create an attribute node. + + Parameters + ---------- + node : Any + The node to annotate the attribute. + + attr_key : str + Attribute type key. + + value : Union[PrimExpr, str] + The value of the attribute. + + Returns + ------- + res : frame.AttrFrame + The result AttrFrame. + """ + node = convert(node) + value = convert(value) + return _ffi_api.Attr(node, attr_key, value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name + """Create a while node. + + Parameters + ---------- + condition : PrimExpr + The termination condition of the loop. + + Returns + ------- + res : frame.WhileFrame + The result WhileFrame. 
+ """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.While(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if node. + + Parameters + ---------- + condition : PrimExpr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then. + + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else. + + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_buffer( + shape, + dtype=T.float32, + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, +) -> frame.DeclBufferFrame: + """Create a buffer declaration node. + + Parameters + ---------- + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : frame.DeclBufferFrame + The result DeclBufferFrame. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] + else: + strides = [] + return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def launch_thread( + thread: Union[IterVar, str], # pylint: disable=redefined-outer-name + extent: PrimExpr, +) -> frame.LaunchThreadFrame: + """Launch a thread. + + Parameters + ---------- + thread : Union[IterVar, str] + The iteration variable. + + extent : PrimExpr + The extent of environment thread. + + Returns + ------- + res : frame.LaunchThreadFrame + The result LaunchThreadFrame. + + Examples + -------- + + .. 
code-block:: python + + from tvm.script.ir_builder import tir as T + brow = T.env_thread("blockIdx.y") + T.launch_thread(brow, 1) + + """ + + if isinstance(thread, str): + thread = String(thread) + return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member + + +def env_thread(thread_tag: str, dtype: str = T.int32) -> IterVar: + """Bind a var to thread env + + Parameters + ---------- + thread_tag : str + The thread type tag. + + dtype : str + The data type of the thread env. + + Returns + ------- + res : IterVar + The result iteration variable gets bound to the thread env. + + """ + return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member + + +def buffer_store( + buffer: Buffer, # pylint: disable=redefined-outer-name + value: PrimExpr, + indices: List[Union[PrimExpr, slice]], +) -> None: + """Buffer store node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + value : PrimExpr + The value to be stored. + + indices : List[Union[PrimExpr, slice]] + The indices location to be stored. + """ + from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + + if not isinstance(indices, (list, tuple, ir.Array)): + indices = [indices] + + expr_indices = [] + for index in indices: + if isinstance(index, slice): + step = 1 if index.step is None else index.step + lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step) + if lanes == 1: + expr_indices.append(index.start) + else: + expr_indices.append(ramp(index.start, step, lanes)) + else: + expr_indices.append(index) + if isinstance(value, bool) and buffer.dtype == "bool": + value = IntImm("bool", value) + return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member + buffer, value, expr_indices + ) + + +def prefetch( + buffer: Buffer, # pylint: disable=redefined-outer-name + bounds: List[ir.Range], +) -> None: + """The prefetch hint for a buffer. + + Parameters + ---------- + buffer : Buffer + The buffer to be prefetched. + bounds : List[Range] + The range to be prefetched. + """ + return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member + + +def evaluate(value: PrimExpr) -> None: + """Evaluate the input expression. + + Parameters + ---------- + value: PrimExpr + The input expression to evaluate. + """ + if isinstance(value, str): + value = StringImm(value) + if isinstance(value, bool): + value = cast(value, "bool") + return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_gen(name: str): + """Generate a function for each PrimExpr dtype. + + Parameters + ---------- + name: str + The ffi function name to call. 
+ """ + + def func( + expr: Union[ + None, + PrimExpr, + Literal["inf", "-inf", "nan"], + int, + float, + ] = None, + *, + is_size_var: bool = False, + ) -> PrimExpr: + if isinstance(expr, str): + expr = float(expr) + return getattr(_ffi_api, name)(expr, is_size_var) + + return func + + +# pylint: disable=invalid-name +int8 = func_gen(("Int8")) +int16 = func_gen(("Int16")) +int32 = func_gen(("Int32")) +int64 = func_gen(("Int64")) +int8x4 = func_gen(("Int8x4")) +int16x4 = func_gen(("Int16x4")) +int32x4 = func_gen(("Int32x4")) +int64x4 = func_gen(("Int64x4")) +int8x8 = func_gen(("Int8x8")) +int16x8 = func_gen(("Int16x8")) +int32x8 = func_gen(("Int32x8")) +int64x8 = func_gen(("Int64x8")) +int8x16 = func_gen(("Int8x16")) +int16x16 = func_gen(("Int16x16")) +int32x16 = func_gen(("Int32x16")) +int64x16 = func_gen(("Int64x16")) +int8x32 = func_gen(("Int8x32")) +int16x32 = func_gen(("Int16x32")) +int32x32 = func_gen(("Int32x32")) +int64x32 = func_gen(("Int64x32")) +int8x64 = func_gen(("Int8x64")) +int16x64 = func_gen(("Int16x64")) +int32x64 = func_gen(("Int32x64")) +int64x64 = func_gen(("Int64x64")) + +uint8 = func_gen(("UInt8")) +uint16 = func_gen(("UInt16")) +uint32 = func_gen(("UInt32")) +uint64 = func_gen(("UInt64")) +uint8x4 = func_gen(("UInt8x4")) +uint16x4 = func_gen(("UInt16x4")) +uint32x4 = func_gen(("UInt32x4")) +uint64x4 = func_gen(("UInt64x4")) +uint8x8 = func_gen(("UInt8x8")) +uint16x8 = func_gen(("UInt16x8")) +uint32x8 = func_gen(("UInt32x8")) +uint64x8 = func_gen(("UInt64x8")) +uint8x16 = func_gen(("UInt8x16")) +uint16x16 = func_gen(("UInt16x16")) +uint32x16 = func_gen(("UInt32x16")) +uint64x16 = func_gen(("UInt64x16")) +uint8x32 = func_gen(("UInt8x32")) +uint16x32 = func_gen(("UInt16x32")) +uint32x32 = func_gen(("UInt32x32")) +uint64x32 = func_gen(("UInt64x32")) +uint8x64 = func_gen(("UInt8x64")) +uint16x64 = func_gen(("UInt16x64")) +uint32x64 = func_gen(("UInt32x64")) +uint64x64 = func_gen(("UInt64x64")) + +float16 = func_gen(("Float16")) +float32 = func_gen(("Float32")) +float64 = func_gen(("Float64")) +float16x4 = func_gen(("Float16x4")) +float32x4 = func_gen(("Float32x4")) +float64x4 = func_gen(("Float64x4")) +float16x8 = func_gen(("Float16x8")) +float32x8 = func_gen(("Float32x8")) +float64x8 = func_gen(("Float64x8")) +float16x16 = func_gen(("Float16x16")) +float32x16 = func_gen(("Float32x16")) +float64x16 = func_gen(("Float64x16")) +float16x32 = func_gen(("Float16x32")) +float32x32 = func_gen(("Float32x32")) +float64x32 = func_gen(("Float64x32")) +float16x64 = func_gen(("Float16x64")) +float32x64 = func_gen(("Float32x64")) +float64x64 = func_gen(("Float64x64")) + +float8_e4m3 = func_gen(("E4M3Float8")) +float8_e4m3x4 = func_gen(("E4M3Float8x4")) +float8_e4m3x8 = func_gen(("E4M3Float8x8")) +float8_e4m3x16 = func_gen(("E4M3Float8x16")) +float8_e4m3x32 = func_gen(("E4M3Float8x32")) +float8_e4m3x64 = func_gen(("E4M3Float8x64")) + +float8_e5m2 = func_gen(("E5M2Float8")) +float8_e5m2x4 = func_gen(("E5M2Float8x4")) +float8_e5m2x8 = func_gen(("E5M2Float8x8")) +float8_e5m2x16 = func_gen(("E5M2Float8x16")) +float8_e5m2x32 = func_gen(("E5M2Float8x32")) +float8_e5m2x64 = func_gen(("E5M2Float8x64")) + +# pylint: enable=invalid-name + + +def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimExpr: + """Construct a new tir.Var with type boolean or cast expression to type boolean. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + is_size_var: bool + Whether or not to return a SizeVar instead of Var. 
+ + Returns + ------- + res : PrimExpr + The new tir.Var with type boolean or casted expression with type boolean. + """ + return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + + +def handle(dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: + """Create a TIR var that represents a pointer. + + Parameters + ---------- + dtype: str + The data type of the pointer. + + storage_scope: str + The storage scope of the pointer. + + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + + Returns + ------- + res : PrimExpr + The new tir.Var with type handle or casted expression with type handle. + """ + is_unknown_type = dtype is None + if dtype is None: + dtype = "void" + return _ffi_api.Handle( # type: ignore[attr-defined] # pylint: disable=no-member + dtype, + storage_scope, + is_size_var, + is_unknown_type, + ) + + +def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr: + """Construct a new tir.Var with type void or cast expression to type void. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type void or casted expression with type void. + """ + return _ffi_api.Void(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + + +@deprecated("T.var", "T.{dtype}") +def var(dtype: str, name: str = "") -> Var: + """Construct a new tir.Var. + + Parameters + ---------- + dtype: str + The dtype of the Var. + + name: str + The name of the Var. + + Returns + ------- + res : Var + The result tir.Var. + """ + return Var(name, dtype) # pylint: disable=no-member + + +def ptr(dtype: str, storage_scope: str = "global", is_size_var: bool = False) -> Var: + """The pointer declaration function. + + Parameters + ---------- + dtype : str + The data type of the pointer. + + storage_scope : str + The storage scope of the pointer. + + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + + Returns + ------- + res : Var + The pointer. + """ + return _ffi_api.Ptr(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + + +@deprecated("T.buffer_var", "T.handle") +def buffer_var(dtype: str, storage_scope: str = "global") -> Var: + """The pointer declaration function. + + Parameters + ---------- + dtype : str + The data type of the pointer. + + storage_scope : str + The storage scope of the pointer. + + Returns + ------- + res : Var + The pointer. + """ + return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member + + +def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin + """Compute the minimum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api.min(a, b) # type: ignore[attr-defined] # pylint: disable=no-member + + +def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin + """Compute the maximum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. 
+ """ + return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member + + +def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar: + """The iteration variable. + + Parameters + ---------- + var : Union[Var, str] + The internal variable that is used for iteration. + + dom : Range + The domain of the iteration. + + iter_type : str + The iteration type. + + thread_tag : str + The thread type tag. + + Returns + ------- + res : IterVar + The iteration variable. + """ + iter_type = getattr(IterVar, iter_type) + return IterVar(dom, v, iter_type, thread_tag) + + +def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: + """ + Create a CommReducer from lambda inputs/outputs and the identities + + Parameters + ---------- + combiner : Callable + A binary function which takes two PrimExpr as input to return a PrimExpr. + + identity : List[PrimExpr] + A list of types of output PrimExpr. + + Returns + ------- + res : CommReducer + The CommReducer. + """ + params = inspect.signature(combiner).parameters + num_args = len(params) + args = [] + for name, i in zip(params.keys(), identity + identity): + if isinstance(i, int): + args.append(Var(name, T.int32)) + else: + args.append(Var(name, i.dtype)) + res = combiner(*args) + if not isinstance(res, tuple): + res = (res,) + return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) + + +def index_map( + mapping: Callable, + *, + inverse_index_map: Optional[Callable] = None, +) -> IndexMap: + """Create a TIR Index mapping""" + return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) + + +def target( + target_config: Union[Dict, str], + host: Optional[Union[Dict, str, Target]] = None, +) -> Target: + """ + Create a target + + Parameters + ---------- + target_config : Union[Dict, str] + The target configuration. + + host : Optional[Union[Dict, str, Target]] + The target configuration. + + Returns + ------- + res : Target + The target. + """ + if not isinstance(target_config, (str, dict)): + raise ValueError(f"T.target expected a config dict or string, but got {type(target_config)}") + if host is not None and not isinstance(host, (str, dict, Target)): + raise ValueError(f"T.target expected the host to be a config dict, string, or T.target, but got {type(host)}") + if isinstance(target_config, dict) and "host" in target_config and host is not None: + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." + ) + return Target(target_config, host) + + +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name + """ + Create a Range object. + + Parameters + ---------- + begin : PrimExpr + The begin value of the range. + + end : Optional[PrimExpr] + The end value of the range. + """ + return ir.Range(begin, end) + + +class meta_var: # pylint: disable=invalid-name + """A meta variable used in TVMScript metaprogramming. It means that the value of the variable + does not appear in the final TIR, but only stays in the parser. + + Parameters + ---------- + value: Any + The meta variable. 
+ """ + + def __init__(self, value: Any) -> None: + self.value = value + + def __iter__(self): + def f(): + for i in self.value: + yield meta_var(i) + + return f() + + +# pylint: disable=invalid-name + + +def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + + +abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin +acos = _op_wrapper(_tir_op.acos) +acosh = _op_wrapper(_tir_op.acosh) +address_of = _op_wrapper(_tir_op.address_of) +asin = _op_wrapper(_tir_op.asin) +asinh = _op_wrapper(_tir_op.asinh) +atan = _op_wrapper(_tir_op.atan) +atan2 = _op_wrapper(_tir_op.atan2) +atanh = _op_wrapper(_tir_op.atanh) +bitwise_and = _op_wrapper(_tir_op.bitwise_and) +bitwise_not = _op_wrapper(_tir_op.bitwise_not) +bitwise_or = _op_wrapper(_tir_op.bitwise_or) +bitwise_xor = _op_wrapper(_tir_op.bitwise_xor) +ceil = _op_wrapper(_tir_op.ceil) +clz = _op_wrapper(_tir_op.clz) +copysign = _op_wrapper(_tir_op.copysign) +cos = _op_wrapper(_tir_op.cos) +cosh = _op_wrapper(_tir_op.cosh) +erf = _op_wrapper(_tir_op.erf) +exp = _op_wrapper(_tir_op.exp) +exp2 = _op_wrapper(_tir_op.exp2) +exp10 = _op_wrapper(_tir_op.exp10) +floor = _op_wrapper(_tir_op.floor) +ceildiv = _op_wrapper(_tir_op.ceildiv) +floordiv = _op_wrapper(_tir_op.floordiv) +floormod = _op_wrapper(_tir_op.floormod) +fmod = _op_wrapper(_tir_op.fmod) +hypot = _op_wrapper(_tir_op.hypot) +if_then_else = _op_wrapper(_tir_op.if_then_else) +infinity = _op_wrapper(_tir_op.infinity) +isfinite = _op_wrapper(_tir_op.isfinite) +isinf = _op_wrapper(_tir_op.isinf) +isnan = _op_wrapper(_tir_op.isnan) +isnullptr = _op_wrapper(_tir_op.isnullptr) +ldexp = _op_wrapper(_tir_op.ldexp) +likely = _op_wrapper(_tir_op.likely) +log = _op_wrapper(_tir_op.log) +log1p = _op_wrapper(_tir_op.log1p) +log2 = _op_wrapper(_tir_op.log2) +log10 = _op_wrapper(_tir_op.log10) +lookup_param = _op_wrapper(_tir_op.lookup_param) +max_value = _op_wrapper(_tir_op.max_value) +min_value = _op_wrapper(_tir_op.min_value) +nearbyint = _op_wrapper(_tir_op.nearbyint) +nextafter = _op_wrapper(_tir_op.nextafter) +popcount = _op_wrapper(_tir_op.popcount) +pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin +q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) +q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) +ret = _op_wrapper(_tir_op.ret) +round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin +rsqrt = _op_wrapper(_tir_op.rsqrt) +shift_left = _op_wrapper(_tir_op.shift_left) +shift_right = _op_wrapper(_tir_op.shift_right) +sigmoid = _op_wrapper(_tir_op.sigmoid) +sin = _op_wrapper(_tir_op.sin) +sinh = _op_wrapper(_tir_op.sinh) +sqrt = _op_wrapper(_tir_op.sqrt) +tan = _op_wrapper(_tir_op.tan) +tanh = _op_wrapper(_tir_op.tanh) +trunc = _op_wrapper(_tir_op.trunc) +truncdiv = _op_wrapper(_tir_op.truncdiv) +truncmod = _op_wrapper(_tir_op.truncmod) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) +tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) +tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) +tvm_check_return = _op_wrapper(_tir_op.tvm_check_return) +call_packed = _op_wrapper(_tir_op.call_packed) +call_cpacked = _op_wrapper(_tir_op.call_cpacked) +call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) +call_cpacked_lowered = 
_op_wrapper(_tir_op.call_cpacked_lowered) +tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) +tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) +tvm_struct_get = _tir_op.tvm_struct_get +tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant) +tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce) +tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync) +tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync) +tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) +tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) +tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_warp_shuffle = _tir_op.tvm_warp_shuffle +tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up +tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down +tvm_warp_activemask = _tir_op.tvm_warp_activemask +ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) +ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) +ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier) +ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count) +ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier) +ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx) +ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier) +create_barriers = _op_wrapper(_tir_op.create_barriers) +assume = _op_wrapper(_tir_op.assume) +undef = _op_wrapper(_tir_op.undef) +TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) +TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) +start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) +end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +anylist_getitem = _op_wrapper(_tir_op.anylist_getitem) +anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) +anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) +anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) +vscale = _op_wrapper(_tir_op.vscale) + + +def _dtype_forward(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + args = (kwargs.pop("dtype"),) + args + return func(*args, **kwargs) + + return wrapped + + +reinterpret = _dtype_forward(_tir_op.reinterpret) +call_extern = _dtype_forward(_tir_op.call_extern) +call_intrin = _dtype_forward(_tir_op.call_intrin) +call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) +call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) +call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) +ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) +ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) +ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) +ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) +mma_store = _dtype_forward(_tir_op.mma_store) +mma_fill = _dtype_forward(_tir_op.mma_fill) +vectorlow = _dtype_forward(_tir_op.vectorlow) +vectorhigh = _dtype_forward(_tir_op.vectorhigh) +vectorcombine = _dtype_forward(_tir_op.vectorcombine) +tvm_mfma = _dtype_forward(_tir_op.tvm_mfma) +tvm_mmac = _dtype_forward(_tir_op.tvm_mmac) +tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store) +tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma) 
+tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store) + +broadcast = Broadcast +ramp = Ramp +fabs = abs +tvm_call_packed = call_packed +tvm_call_cpacked = call_cpacked +tvm_call_packed_lowered = call_packed_lowered +tvm_call_cpacked_lowered = call_cpacked_lowered + +# pylint: enable=invalid-name + +__all__ = [ + "int8", + "int16", + "int32", + "int64", + "int8x4", + "int16x4", + "int32x4", + "int64x4", + "int8x8", + "int16x8", + "int32x8", + "int64x8", + "int8x16", + "int16x16", + "int32x16", + "int64x16", + "int8x32", + "int16x32", + "int32x32", + "int64x32", + "int8x64", + "int16x64", + "int32x64", + "int64x64", + "uint8", + "uint16", + "uint32", + "uint64", + "uint8x4", + "uint16x4", + "uint32x4", + "uint64x4", + "uint8x8", + "uint16x8", + "uint32x8", + "uint64x8", + "uint8x16", + "uint16x16", + "uint32x16", + "uint64x16", + "uint8x32", + "uint16x32", + "uint32x32", + "uint64x32", + "uint8x64", + "uint16x64", + "uint32x64", + "uint64x64", + "float8_e4m3", + "float8_e5m2", + "float16", + "float32", + "float64", + "float8_e4m3x4", + "float8_e5m2x4", + "float16x4", + "float32x4", + "float64x4", + "float8_e4m3x8", + "float8_e5m2x8", + "float16x8", + "float32x8", + "float64x8", + "float8_e4m3x16", + "float8_e5m2x16", + "float16x16", + "float32x16", + "float64x16", + "float8_e4m3x32", + "float8_e5m2x32", + "float16x32", + "float32x32", + "float64x32", + "float8_e4m3x64", + "float8_e5m2x64", + "float16x64", + "float32x64", + "float64x64", + "buffer", + "buffer_decl", + "prim_func", + "arg", + "func_name", + "func_attr", + "func_ret", + "match_buffer", + "block", + "init", + "where", + "reads", + "writes", + "block_attr", + "alloc_buffer", + "axis", + "serial", + "parallel", + "vectorized", + "unroll", + "thread_binding", + "grid", + "Assert", + "realize", + "allocate", + "allocate_const", + "attr", + "While", + "If", + "Then", + "Else", + "decl_buffer", + "launch_thread", + "env_thread", + "buffer_store", + "prefetch", + "customized_code", + "evaluate", + "boolean", + "handle", + "void", + "var", + "ptr", + "min", + "max", + "iter_var", + "comm_reducer", + "index_map", + "target", + "buffer_var", + "abs", + "fabs", + "acos", + "acosh", + "address_of", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", + "ceil", + "clz", + "copysign", + "cos", + "cosh", + "erf", + "exp", + "exp2", + "exp10", + "floor", + "ceildiv", + "floordiv", + "floormod", + "fmod", + "hypot", + "if_then_else", + "infinity", + "isfinite", + "isinf", + "isnan", + "isnullptr", + "ldexp", + "likely", + "log", + "log1p", + "log2", + "log10", + "lookup_param", + "max_value", + "min_value", + "nearbyint", + "nextafter", + "popcount", + "pow", + "q_multiply_shift", + "q_multiply_shift_per_axis", + "ret", + "reinterpret", + "round", + "rsqrt", + "shift_left", + "shift_right", + "sigmoid", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "trunc", + "truncdiv", + "truncmod", + "tvm_access_ptr", + "tvm_throw_last_error", + "tvm_stack_alloca", + "tvm_stack_make_shape", + "tvm_stack_make_array", + "tvm_check_return", + "call_packed", + "call_cpacked", + "call_packed_lowered", + "call_cpacked_lowered", + "call_extern", + "call_intrin", + "call_llvm_intrin", + "call_llvm_pure_intrin", + "call_pure_extern", + "tvm_tuple", + "tvm_struct_set", + "tvm_struct_get", + "tvm_thread_invariant", + "tvm_thread_allreduce", + "tvm_load_matrix_sync", + "tvm_mma_sync", + "tvm_bmma_sync", + "tvm_fill_fragment", + "tvm_store_matrix_sync", + "tvm_storage_sync", + 
"tvm_warp_shuffle", + "tvm_warp_shuffle_up", + "tvm_warp_shuffle_down", + "tvm_warp_activemask", + "ptx_mma", + "ptx_mma_sp", + "ptx_wgmma_ss", + "ptx_wgmma_rs", + "ptx_tcgen05_mma_ss", + "ptx_ldmatrix", + "ptx_cp_async", + "ptx_cp_async_bulk", + "ptx_wait_group", + "ptx_commit_group", + "ptx_cp_async_barrier", + "ptx_init_barrier_thread_count", + "ptx_arrive_barrier", + "ptx_arrive_barrier_expect_tx", + "ptx_wait_barrier", + "create_barriers", + "mma_store", + "mma_fill", + "vectorlow", + "vectorhigh", + "vectorcombine", + "tvm_mfma", + "tvm_mmac", + "tvm_mfma_store", + "tvm_rdna_wmma", + "tvm_rdna_wmma_store", + "assume", + "undef", + "tvm_call_packed", + "tvm_call_cpacked", + "tvm_call_packed_lowered", + "tvm_call_cpacked_lowered", + "TVMBackendAllocWorkspace", + "TVMBackendFreeWorkspace", + "start_profile_intrinsic", + "end_profile_intrinsic", + "meta_var", + "anylist_getitem", + "anylist_resetitem", + "anylist_setitem_call_packed", + "anylist_setitem_call_cpacked", + "llvm_lookup_intrinsic_id", + "type_annotation", + "broadcast", + "ramp", + "cast", + # tvm.tir.expr + "Var", + "SizeVar", + "Reduce", + "FloatImm", + "IntImm", + "StringImm", + "Cast", + "Add", + "Sub", + "Mul", + "Div", + "Mod", + "FloorDiv", + "FloorMod", + "Min", + "Max", + "EQ", + "NE", + "LT", + "LE", + "GT", + "GE", + "And", + "Or", + "Not", + "Select", + "BufferLoad", + "ProducerLoad", + "Ramp", + "Broadcast", + "Shuffle", + "Call", + "CallEffectKind", + "let", + "LetStmt", + "Let", + "IterVar", + "CommReducer", + "Range", + "vscale", +] diff --git a/tilelang/original/tilelang/language/atomic.py b/tilelang/original/tilelang/language/atomic.py new file mode 100644 index 0000000000000000000000000000000000000000..a801f75f4c8bf02eb1d3058c810c66994b28d159 --- /dev/null +++ b/tilelang/original/tilelang/language/atomic.py @@ -0,0 +1,391 @@ +"""Atomic operations exposed on the TileLang language surface.""" + +from __future__ import annotations + +import tilelang.language as T +from tvm import ir +from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op +from tilelang.utils.language import to_buffer_region, legalize_pairwise_extents + +_MEMORY_ORDER_ID_MAP = { + "relaxed": 0, + "consume": 1, + "acquire": 2, + "release": 3, + "acq_rel": 4, + "seq_cst": 5, +} + + +def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: + """ + Perform an atomic maximum on the value stored at dst with an optional memory-order. + + If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern. + + Parameters: + dst (Buffer): Destination buffer/address to apply the atomic max. + value (PrimExpr): Value to compare/store atomically. + memory_order (Optional[str]): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst"). + If provided, it is translated to the corresponding numeric memory-order id before the call. + return_prev (bool): If True, return the previous value; if False, return handle (default False). + + Returns: + PrimExpr: A handle/expression representing the issued atomic maximum operation, or the previous value if return_prev is True. 
+ + Examples: + >>> # Basic atomic max operation + >>> counter = T.Tensor([1], "float32", name="counter") + >>> atomic_max(counter, 42.0) + + >>> # With memory ordering + >>> atomic_max(counter, 100.0, memory_order="acquire") + + >>> # Get the previous value + >>> prev_value = atomic_max(counter, 50.0, return_prev=True) + >>> # prev_value now contains the value that was in counter before the max operation + + >>> # Use in parallel reduction to find global maximum + >>> @T.prim_func + >>> def find_max(data: T.Buffer, result: T.Buffer): + >>> for i in T.thread_binding(128, "threadIdx.x"): + >>> atomic_max(result, data[i]) + """ + func_name = "AtomicMaxRet" if return_prev else "AtomicMax" + return_type = dst.dtype if return_prev else "handle" + + if memory_order is None: + return T.call_extern(return_type, func_name, T.address_of(dst), value) + else: + return T.call_extern( + return_type, + func_name, + T.address_of(dst), + value, + _MEMORY_ORDER_ID_MAP[memory_order], + ) + + +def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: + """ + Atomically update the value at dst to the minimum of its current value and value. + + If memory_order is provided, it selects the memory-order semantic used by the underlying extern call; + allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally + to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument. + + Parameters: + dst (Buffer): Destination buffer/address to apply the atomic min. + value (PrimExpr): Value to compare/store atomically. + memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering. + return_prev (bool): If True, return the previous value; if False, return handle (default False). + + Returns: + PrimExpr: A handle expression representing the atomic-min operation, or the previous value if return_prev is True. + + Examples: + >>> # Basic atomic min operation + >>> min_val = T.Tensor([1], "int32", name="min_val") + >>> atomic_min(min_val, 10) + + >>> # Find minimum across threads + >>> @T.prim_func + >>> def find_min(data: T.Buffer, result: T.Buffer): + >>> for i in T.thread_binding(256, "threadIdx.x"): + >>> atomic_min(result, data[i]) + + >>> # Track minimum with previous value + >>> threshold = T.Tensor([1], "float32", name="threshold") + >>> old_min = atomic_min(threshold, 3.14, return_prev=True) + >>> # old_min contains the previous minimum value + + >>> # With relaxed memory ordering for performance + >>> atomic_min(min_val, 5, memory_order="relaxed") + """ + func_name = "AtomicMinRet" if return_prev else "AtomicMin" + return_type = dst.dtype if return_prev else "handle" + + if memory_order is None: + return T.call_extern(return_type, func_name, T.address_of(dst), value) + else: + return T.call_extern( + return_type, + func_name, + T.address_of(dst), + value, + _MEMORY_ORDER_ID_MAP[memory_order], + ) + + +def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr: + """ + Atomically add `value` into `dst`, returning a handle to the operation. + + Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. 
If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`. + + Parameters: + dst (Buffer): Destination buffer/address to apply the atomic add. + value (PrimExpr): Value to add atomically. + memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering. + return_prev (bool): If True, return the previous value; if False, return handle (default False). + use_tma (bool): If True, use TMA (cp.reduce) to perform the atomic add. This is available only for sm90+ (default False). + + Returns: + PrimExpr: A handle representing the atomic addition operation, or the previous value if return_prev is True. + + Examples: + >>> # Basic atomic addition + >>> counter = T.Tensor([1], "int32", name="counter") + >>> atomic_add(counter, 1) # Increment counter by 1 + + >>> # Parallel sum reduction + >>> @T.prim_func + >>> def parallel_sum(data: T.Buffer, result: T.Buffer): + >>> for i in T.thread_binding(1024, "threadIdx.x"): + >>> atomic_add(result, data[i]) + + >>> # Get previous value for debugging + >>> old_value = atomic_add(counter, 5, return_prev=True) + >>> # old_value contains the value before adding 5 + + >>> # Tensor-to-tensor atomic add (tile-region based) + >>> src_tensor = T.Tensor([128, 64], "float32", name="src") + >>> dst_tensor = T.Tensor([128, 64], "float32", name="dst") + >>> atomic_add(dst_tensor, src_tensor) # Add entire tensors atomically + + >>> # With memory ordering for scalar operations + >>> atomic_add(counter, 10, memory_order="acquire") + + >>> # Accumulate gradients in training + >>> gradients = T.Tensor([1000], "float32", name="gradients") + >>> global_grad = T.Tensor([1000], "float32", name="global_grad") + >>> atomic_add(global_grad, gradients) + """ + + def get_extent(data): + """ + Return the inferred extent (shape) of a buffer-like object. + + If `data` is a Var bound to a let value, the let value is resolved before inspection. + Parameters: + data: A Var, Buffer, or BufferRegion to inspect. + + Returns: + The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined. 
+ """ + if isinstance(data, Var) and T.has_let_value(data): + data = T.get_let_value(data) + if isinstance(data, Buffer): + return data.shape + elif isinstance(data, BufferRegion): + return [x.extent for x in data.region] + else: + return None + + src_extent = get_extent(value) + dst_extent = get_extent(dst) + + if dst_extent is None and src_extent is None: + func_name = "AtomicAddRet" if return_prev else "AtomicAdd" + return_type = dst.dtype if return_prev else "handle" + + # Pass destination by pointer to match device signature + if memory_order is None: + return T.call_extern(return_type, func_name, T.address_of(dst), value) + else: + return T.call_extern( + return_type, + func_name, + T.address_of(dst), + value, + _MEMORY_ORDER_ID_MAP[memory_order], + ) + + if isinstance(dst, Buffer) and isinstance(value, Buffer): + ir.assert_structural_equal(dst.shape, value.shape) + + assert src_extent or dst_extent, "Can't deduce atomicadd extents from args" + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + value = to_buffer_region(value, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) + + # Note: tile-region-based atomic operations don't support return_prev yet + # This would need to be implemented in the tile runtime + if return_prev: + raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") + + if memory_order is None: + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) + else: + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) + + +def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: + """Perform an atomic addition operation with double-width operands. 
+ + Args: + dst (Buffer): Destination buffer where the atomic addition will be performed + value (PrimExpr): Value to be atomically added (double-width) + return_prev (bool): If True, return the previous value; if False, return handle (default False) + + Returns: + PrimExpr: Handle to the double-width atomic addition operation, or the previous value if return_prev is True + + Examples: + >>> # Atomic addition with FP16 pairs + >>> half_dst = T.Tensor([2], "float16", name="half_dst") + >>> half_val = T.Tensor([2], "float16", name="half_val") + >>> atomic_addx2(half_dst, half_val) + + >>> # BF16 vectorized atomic add (requires CUDA Arch > 750) + >>> bf16_dst = T.Tensor([2], "bfloat16", name="bf16_dst") + >>> bf16_val = T.Tensor([2], "bfloat16", name="bf16_val") + >>> atomic_addx2(bf16_dst, bf16_val) + + >>> # Get previous paired values + >>> prev_values = atomic_addx2(half_dst, half_val, return_prev=True) + >>> # prev_values is a half2 containing the two previous FP16 values + + >>> # Efficient gradient accumulation for mixed precision training + >>> @T.prim_func + >>> def accumulate_fp16_gradients(grads: T.Buffer, global_grads: T.Buffer): + >>> for i in T.thread_binding(128, "threadIdx.x"): + >>> for j in range(0, grads.shape[1], 2): # Process in pairs + >>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2]) + """ + func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2" + return_type = dst.dtype if return_prev else "handle" + return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value)) + + +def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: + """Perform an atomic addition operation with quad-width operands. + + Args: + dst (Buffer): Destination buffer where the atomic addition will be performed + value (PrimExpr): Value to be atomically added (quad-width) + return_prev (bool): If True, return the previous value; if False, return handle (default False) + + Returns: + PrimExpr: Handle to the quad-width atomic addition operation, or the previous value if return_prev is True + + Examples: + >>> # Atomic addition with float4 (requires CUDA Arch >= 900) + >>> float4_dst = T.Tensor([4], "float32", name="float4_dst") + >>> float4_val = T.Tensor([4], "float32", name="float4_val") + >>> atomic_addx4(float4_dst, float4_val) + + >>> # Get previous float4 values + >>> prev_float4 = atomic_addx4(float4_dst, float4_val, return_prev=True) + >>> # prev_float4 is a float4 containing the four previous float32 values + + >>> # High-throughput gradient accumulation for large models + >>> @T.prim_func + >>> def accumulate_float4_gradients(grads: T.Buffer, global_grads: T.Buffer): + >>> for i in T.thread_binding(256, "threadIdx.x"): + >>> for j in range(0, grads.shape[1], 4): # Process 4 floats at once + >>> atomic_addx4(global_grads[i, j:j+4], grads[i, j:j+4]) + + >>> # Efficient RGBA pixel blending + >>> rgba_dst = T.Tensor([4], "float32", name="rgba_dst") # R, G, B, A channels + >>> rgba_add = T.Tensor([4], "float32", name="rgba_add") + >>> atomic_addx4(rgba_dst, rgba_add) # Atomic blend of all 4 channels + """ + func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4" + return_type = "float4" if "float" in str(dst.dtype).lower() else "handle" + return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value)) + + +def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: + """ + Load a value from the given buffer using the specified atomic memory ordering. 
+ + Performs an atomic load from `src` and returns a PrimExpr representing the loaded value. + memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire", + "release", "acq_rel", or "seq_cst" (default). + Raises KeyError if an unknown memory_order is provided. + + Note: atomic_load always returns the loaded value, so no return_prev parameter is needed. + + Examples: + >>> # Basic atomic load + >>> shared_var = T.Tensor([1], "int32", name="shared_var") + >>> value = atomic_load(shared_var) + + >>> # Load with specific memory ordering + >>> value = atomic_load(shared_var, memory_order="acquire") + >>> # Ensures all subsequent memory operations happen after this load + + >>> # Relaxed load for performance-critical code + >>> value = atomic_load(shared_var, memory_order="relaxed") + + >>> # Producer-consumer pattern + >>> @T.prim_func + >>> def consumer(flag: T.Buffer, data: T.Buffer, result: T.Buffer): + >>> # Wait until producer sets flag + >>> while atomic_load(flag, memory_order="acquire") == 0: + >>> pass # Spin wait + >>> # Now safely read data + >>> result[0] = data[0] + + >>> # Load counter for statistics + >>> counter = T.Tensor([1], "int64", name="counter") + >>> current_count = atomic_load(counter, memory_order="relaxed") + """ + return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src), _MEMORY_ORDER_ID_MAP[memory_order]) + + +def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr: + """ + Perform an atomic store of `src` into `dst` with the given memory ordering. + + Parameters: + dst (Buffer): Destination buffer to store into. + src (PrimExpr): Value to store. + memory_order (str, optional): Memory ordering name; one of "relaxed", "consume", + "acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst". + The name is mapped to an internal numeric ID used by the underlying runtime. + + Returns: + PrimExpr: A handle representing the issued atomic store operation. + + Raises: + KeyError: If `memory_order` is not one of the supported names. + + Note: atomic_store doesn't return a previous value, so no return_prev parameter is needed. + + Examples: + >>> # Basic atomic store + >>> shared_var = T.Tensor([1], "int32", name="shared_var") + >>> atomic_store(shared_var, 42) + + >>> # Store with release ordering to publish data + >>> data = T.Tensor([1000], "float32", name="data") + >>> ready_flag = T.Tensor([1], "int32", name="ready_flag") + >>> # ... fill data ... 
+ >>> atomic_store(ready_flag, 1, memory_order="release") + >>> # Ensures all previous writes are visible before flag is set + + >>> # Relaxed store for performance + >>> atomic_store(shared_var, 100, memory_order="relaxed") + + >>> # Producer-consumer synchronization + >>> @T.prim_func + >>> def producer(data: T.Buffer, flag: T.Buffer): + >>> data[0] = 3.14159 # Write data first + >>> atomic_store(flag, 1, memory_order="release") + >>> # Consumer can now safely read data after seeing flag == 1 + + >>> # Update configuration atomically + >>> config = T.Tensor([1], "int32", name="config") + >>> new_config = 0x12345678 + >>> atomic_store(config, new_config, memory_order="seq_cst") + + >>> # Thread-safe logging counter + >>> log_counter = T.Tensor([1], "int64", name="log_counter") + >>> atomic_store(log_counter, 0) # Reset counter atomically + """ + return T.call_extern("handle", "AtomicStore", T.address_of(dst), src, _MEMORY_ORDER_ID_MAP[memory_order]) diff --git a/tilelang/original/tilelang/language/builtin.py b/tilelang/original/tilelang/language/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..2932656ca4edd6c806d77b41adbf2f3e60beaa12 --- /dev/null +++ b/tilelang/original/tilelang/language/builtin.py @@ -0,0 +1,929 @@ +"""Builtin operations exposed on the TileLang language surface.""" + +from __future__ import annotations + +from tilelang import tvm as tvm +from tilelang.language import ptx_arrive_barrier, evaluate +from tilelang.language.kernel import get_thread_bindings, get_block_extents +from tilelang.utils.target import check_hip_availability +from tvm import DataType, tir +from tvm.runtime import convert +from typing import Any +from tvm.tir import PrimExpr, Var, Call, BufferLoad, BufferRegion + +_IS_HIP_AVAILABLE = check_hip_availability() + + +def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None: + """ + Normalize warp sizing arguments so both Python ints and PrimExpr values + are accepted uniformly. + """ + if value is None: + return None + if isinstance(value, PrimExpr): + return value + if isinstance(value, int): + return tir.IntImm("int32", value) + raise TypeError(f"Expect warp sizing argument to be int or PrimExpr, but got {type(value)}.") + + +def create_list_of_mbarrier(*args: Any) -> Call: + """ + Create a list of memory barrier handles. + + Parameters + ---------- + *args : list or Any + Either a single list of arguments, or multiple arguments directly. + + Returns + ------- + tvm.tir.Call + Handle to the created list of memory barriers. + + Raises + ------ + TypeError + If the input is not a list or variadic arguments. + + Examples + -------- + >>> create_list_of_mbarrier([128, 128]) + >>> create_list_of_mbarrier(128, 128) + """ + if len(args) == 1 and isinstance(args[0], list): + return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args[0]) + elif len(args) >= 1: + return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args) + else: + raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.") + + +def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = None) -> PrimExpr: + """Explicitly load via CUDA read-only data cache. + + Prefer calling with a BufferLoad: `T.__ldg(x[i])` emits `__ldg(&x[i])` on CUDA. + On non-CUDA backends, falls back to a regular load. + + Args: + load_or_buf: A `BufferLoad` like `x[i]`, or a `Buffer`. + index: Optional index when passing a `Buffer` directly. + + Returns: + PrimExpr: The loaded value. 
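+
+    Example (illustrative sketch; ``x`` and ``i`` stand in for a global buffer
+    and an index expression):
+        >>> v = T.__ldg(x[i])   # preferred form: pass a BufferLoad
+        >>> v = T.__ldg(x, i)   # equivalent form: pass the Buffer plus an index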
+ """ + if isinstance(load_or_buf, BufferLoad): + dtype = load_or_buf.dtype + return tir.call_intrin(str(dtype), tir.op.Op.get("tl.__ldg"), load_or_buf) + if isinstance(load_or_buf, tir.Buffer): + if index is None: + raise ValueError("T.__ldg(Buffer, index) requires an index when passing a Buffer.") + idx = index + if isinstance(index, (list, tuple)): + if len(index) != 1: + raise ValueError("T.__ldg currently supports 1D flattened indices.") + idx = index[0] + bl = BufferLoad(load_or_buf, [idx]) + return tir.call_intrin(str(load_or_buf.dtype), tir.op.Op.get("tl.__ldg"), bl) + raise TypeError("T.__ldg expects a BufferLoad or a Buffer.") + + +def get_mbarrier(*args): + """Retrieve a memory barrier operation. + + Args: + *args: Variable arguments to specify which memory barrier to retrieve + + Returns: + tir.Call: A handle to the requested memory barrier + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), *args) + + +def create_tma_descriptor(*args): + """Create a Tensor Memory Access (TMA) descriptor. + + Args: + *args: Variable arguments defining the TMA descriptor configuration + + Returns: + tir.Call: A handle to the created TMA descriptor + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.create_tma_descriptor"), *args) + + +def tma_load(*args): + """Perform a Tensor Memory Access (TMA) load operation. + + Args: + *args: Variable arguments specifying the TMA load parameters + + Returns: + tir.Call: A handle to the TMA load operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.tma_load"), *args) + + +def fence_proxy_async(*args): + """Create a fence for asynchronous proxy operations. + + Args: + *args: Variable arguments for fence configuration + + Returns: + tir.Call: A handle to the fence operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_proxy_async"), *args) + + +def tma_store_arrive(*args): + """Signal the arrival of a TMA store operation. + + Args: + *args: Variable arguments for the store arrival operation + + Returns: + tir.Call: A handle to the store arrive operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_arrive"), *args) + + +def tma_store_wait(*args): + """Wait for completion of TMA store operations. + + Args: + *args: Variable arguments specifying which store operations to wait for + + Returns: + tir.Call: A handle to the store wait operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_wait"), *args) + + +def set_max_nreg(reg_count: int, is_inc: int): + """Set the maximum number of registers to use. 
+ Detailed Documentation: + https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg + + Args: + reg_count: int + The number of registers to allocate + is_inc: int + Whether to increment or decrement the register count + 0 if decrement, 1 if increment + + Returns: + tir.Call: A handle to the register setting operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), reg_count, is_inc) + + +def inc_max_nreg(reg_count: int): + """Increment the maximum number of registers to use.""" + return set_max_nreg(reg_count, 1) + + +def dec_max_nreg(reg_count: int): + """Decrement the maximum number of registers to use.""" + return set_max_nreg(reg_count, 0) + + +def annotate_producer_reg_dealloc(reg_count: int = 24): + """Annotate the producer reg dealloc.""" + return dec_max_nreg(reg_count) + + +def annotate_consumer_reg_alloc(reg_count: int = 240): + """Annotate the consumer reg alloc.""" + return inc_max_nreg(reg_count) + + +def no_set_max_nreg(): + """Disable the maximum register limit setting.""" + return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg")) + + +def disable_warp_group_reg_alloc(): + """Disable the warp group reg alloc.""" + return no_set_max_nreg() + + +def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var): + """Wait for memory barrier parity condition. + + Args: + mbarrier: Optional[int, PrimExpr] + The memory barrier to wait on + parity: Optional[int, Var] + The parity value to wait for + Examples: + .. code-block:: python + + # Wait for parity 0 on barrier 0 + T.mbarrier_wait_parity(0, 0) + + # Wait for parity value in variable ko on barrier 1 + T.mbarrier_wait_parity(1, ko) + + # Wait using barrier handle + barrier = T.get_mbarrier(0) + T.mbarrier_wait_parity(barrier, 1) + + # Common usage in pipelined kernels: + for ko in range(num_stages): + # Producer waits for consumer to finish previous iteration + T.mbarrier_wait_parity(1, ko ^ 1) + # Producer copies data + T.copy(A_global, A_shared) + # Producer signals data ready + T.mbarrier_arrive(0) + + # Consumer waits for producer data + T.mbarrier_wait_parity(0, ko) + # Consumer computes + T.gemm(A_shared, B_shared, C_local) + # Consumer signals completion + T.mbarrier_arrive(1) + Returns: + tir.Call: A handle to the barrier wait operation + """ + if isinstance(mbarrier, (tir.Call, tir.BufferLoad)): + mbarrier = mbarrier + elif isinstance(mbarrier, (tir.PrimExpr, int)): + mbarrier = get_mbarrier(mbarrier) + elif isinstance(mbarrier, tir.Buffer): + mbarrier = tir.BufferLoad(mbarrier, [0]) + else: + raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}") + return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity) + + +def mbarrier_arrive(mbarrier: int | PrimExpr | tir.Call): + """Arrive at memory barrier. + + Args: + mbarrier: Optional[int, PrimExpr] + The memory barrier to arrive at + """ + if isinstance(mbarrier, (tir.Call, tir.BufferLoad)): + mbarrier = mbarrier + elif isinstance(mbarrier, (tir.PrimExpr, int)): + mbarrier = get_mbarrier(mbarrier) + elif isinstance(mbarrier, tir.Buffer): + mbarrier = tir.BufferLoad(mbarrier, [0]) + else: + raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}") + return ptx_arrive_barrier(mbarrier) + + +def mbarrier_expect_tx(*args): + """Set expected transaction count for memory barrier. 
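+
+    In PTX terms this maps to ``mbarrier.expect_tx``: it registers how many
+    transaction bytes the barrier should expect (for example, the size of an
+    in-flight TMA copy) before a corresponding wait can complete.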
+ + Args: + *args: Variable arguments specifying the expected transaction count + + Returns: + tir.Call: A handle to the barrier expectation operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args) + + +def warpgroup_arrive(): + """Signal warpgroup readiness for subsequent WGMMA operations. + + Returns: + tir.Call: A handle to the warpgroup arrive operation. + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_arrive")) + + +def warpgroup_commit_batch(): + """Commit the current warpgroup batch for WGMMA operations. + + Returns: + tir.Call: A handle to the warpgroup commit batch operation. + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_commit_batch")) + + +def warpgroup_wait(num_mma: int): + """Wait for completion of the specified warpgroup batch. + + Args: + num_mma: int + Identifier of the warpgroup MMA batch to wait on. + + Returns: + tir.Call: A handle to the warpgroup wait operation. + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) + + +def get_lane_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: + """Return the logical lane index of the calling thread within a warp. + + Parameters + ---------- + warp_size : Optional[int, PrimExpr] + Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD. + + Example + ------- + >>> lane = T.get_lane_idx() + >>> custom_lane = T.get_lane_idx(64) # override warp size explicitly + + Implementation Notes + -------------------- + Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in + `src/tl_templates/cuda/intrin.h`, which computes the lane index from the + linear thread id using the provided `warp_size`. + """ + warp_size_expr = _normalize_index_arg(warp_size) + if warp_size_expr is None: + return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx")) + return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) + + +def get_warp_idx_sync( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: + """Return the canonical warp index, assuming the warp's threads are converged. + + Parameters + ---------- + warp_size : Optional[int, PrimExpr] + Logical warp size used for the index calculation. + + Example + ------- + >>> warp = T.get_warp_idx_sync() + >>> custom_warp = T.get_warp_idx_sync(64) + + Implementation Notes + -------------------- + Emits `tl::get_warp_idx_sync(warp_size)` which divides the block-linear + thread id by `warp_size`, matching the semantics of CUTLASS' canonical helpers. + """ + warp_size_expr = _normalize_index_arg(warp_size) + if warp_size_expr is None: + return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync")) + return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) + + +def get_warp_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: + """Return the canonical warp index without synchronizing the warp. + + Parameters + ---------- + warp_size : Optional[int, PrimExpr] + Logical warp size used for the index calculation. + + Example + ------- + >>> warp = T.get_warp_idx() + >>> custom_warp = T.get_warp_idx(64) + + Implementation Notes + -------------------- + Lowers to `tl::get_warp_idx(warp_size)` which divides the block-linear + thread id by the provided `warp_size` without requiring warp convergence. 
+ """ + warp_size_expr = _normalize_index_arg(warp_size) + if warp_size_expr is None: + return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx")) + return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"), warp_size_expr) + + +def get_warp_group_idx( + warp_size: int | PrimExpr | None = None, + warps_per_group: int | PrimExpr | None = None, +) -> PrimExpr: + """Return the canonical warp group index for the calling thread. + + Parameters + ---------- + warp_size : Optional[int, PrimExpr] + Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD). + warps_per_group : Optional[int, PrimExpr] + Number of warps per warp-group. Defaults to 4 on NVIDIA architectures. + + Example + ------- + >>> group = T.get_warp_group_idx() + >>> custom_group = T.get_warp_group_idx(32, 6) # treat 6 warps as a group + + Implementation Notes + -------------------- + Generates `tl::get_warp_group_idx(warp_size, warps_per_group)` which + divides the block-linear thread id by `warp_size * warps_per_group`, + matching the canonical ordering while allowing architecture-specific overrides. + """ + warp_size_expr = _normalize_index_arg(warp_size) + warps_per_group_expr = _normalize_index_arg(warps_per_group) + args = [] + if warp_size_expr is not None: + args.append(warp_size_expr) + if warps_per_group_expr is not None: + if warp_size_expr is None: + raise ValueError("get_warp_group_idx expects `warp_size` when specifying `warps_per_group`.") + args.append(warps_per_group_expr) + return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) + + +def shuffle_elect(thread_extent: int) -> PrimExpr: + """Elect exactly one lane within a logical thread group. + + Parameters + ---------- + thread_extent : int + Size (in threads) of the group in which a single lane should be elected. + Passing 0 elects a single lane in the entire thread block. + + Example + ------- + >>> is_leader = T.shuffle_elect(64) + >>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0)) + + Implementation Notes + -------------------- + Lowered to the CUDA helper `tl::tl_shuffle_elect()` defined in + `src/tl_templates/cuda/intrin.h`, which relies on + `cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or + `__shfl_sync`) to pick one lane per group. + """ + return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) + + +def warpgroup_fence_operand( + buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None +): + """Insert a warpgroup fence for the destination accumulator registers. + + This prevents NVCC from sinking uses of accumulator fragments past the corresponding + WGMMA operations by issuing an empty inline assembly barrier on every register. + + Args: + buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr + A buffer representing the accumulator fragment, a buffer load/region + that identifies a starting element within the fragment, or a pointer expression + (e.g., tvm_access_ptr/address_of/typed Var). + offset: int | PrimExpr + Element offset from the start of the accumulator fragment. + num_regs: int | PrimExpr | None + Number of 32-bit registers to fence. If None and a Buffer is provided, it will be + derived from the buffer shape and dtype. + dtype: str | None + Data type string of the accumulator elements. When passing a buffer or + buffer-derived expression, dtype is inferred. It is required only when + passing a raw pointer expression that cannot be inferred. 
+ + Returns: + tir.Call: A handle to the warpgroup fence operation. + """ + if isinstance(buffer_or_ptr, BufferLoad): + # Treat BufferLoad as a request to fence starting from the loaded element's address + buf = buffer_or_ptr.buffer + data_ptr = buf.data + inferred_dtype = buf.dtype + if dtype is not None and dtype != inferred_dtype: + raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.") + dtype = inferred_dtype + # Compute element offset from indices using strides if present, otherwise row-major + if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0: + elem_off = 0 + for idx, stride in zip(buffer_or_ptr.indices, buf.strides): + elem_off = elem_off + idx * stride + else: + elem_off = 0 + stride_acc = 1 + for idx, dim in zip(reversed(buffer_or_ptr.indices), reversed(buf.shape)): + elem_off = elem_off + idx * stride_acc + stride_acc = stride_acc * dim + # Combine with user-provided offset + offset = elem_off + convert(offset) + if num_regs is None: + raise ValueError("num_regs must be provided when passing a BufferLoad.") + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.warpgroup_fence_operand"), + dtype, + data_ptr, + convert(offset), + convert(num_regs), + ) + ) + + if isinstance(buffer_or_ptr, tir.Buffer): + data_ptr = buffer_or_ptr.data + inferred_dtype = buffer_or_ptr.dtype + if dtype is not None and dtype != inferred_dtype: + raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.") + dtype = inferred_dtype + if num_regs is None: + total_elems = 1 + for dim in buffer_or_ptr.shape: + if isinstance(dim, tir.IntImm): + total_elems *= int(dim) + else: + raise ValueError("warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") + bits_per_elem = DataType(dtype).bits + num_regs = (total_elems * bits_per_elem + 31) // 32 + elif isinstance(buffer_or_ptr, BufferRegion): + buf = buffer_or_ptr.buffer + data_ptr = buf.data + inferred_dtype = buf.dtype + if dtype is not None and dtype != inferred_dtype: + raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.") + dtype = inferred_dtype + # Compute element offset from region min using strides if present, otherwise row-major + if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0: + elem_off = 0 + for r, stride in zip(buffer_or_ptr.region, buf.strides): + elem_off = elem_off + r.min * stride + else: + elem_off = 0 + stride_acc = 1 + for r, dim in zip(reversed(buffer_or_ptr.region), reversed(buf.shape)): + elem_off = elem_off + r.min * stride_acc + stride_acc = stride_acc * dim + # Combine with user-provided offset + offset = elem_off + convert(offset) + # Try derive num_regs from region extents if fully static; otherwise require user input + if num_regs is None: + total_elems = 1 + static = True + for r in buffer_or_ptr.region: + if isinstance(r.extent, tir.IntImm): + total_elems *= int(r.extent) + else: + static = False + break + if static: + bits_per_elem = DataType(dtype).bits + num_regs = (total_elems * bits_per_elem + 31) // 32 + else: + raise ValueError("warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic.") + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.warpgroup_fence_operand"), + dtype, + data_ptr, + convert(offset), + convert(num_regs), + ) + ) + else: + data_ptr = buffer_or_ptr + # Try to infer dtype from common pointer expressions when not provided + if dtype is None: + inferred = None + # Case 1: Pointer from Buffer.access_ptr -> 
tir.builtin.tvm_access_ptr + if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()): + # args[0] is a type annotation call; its dtype carries the element dtype + inferred = str(data_ptr.args[0].dtype) + # Case 2: Pointer from tir.address_of(BufferLoad(...)) + elif isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.address_of()): + # args[0] should be a BufferLoad; its dtype is the element dtype + inferred = str(data_ptr.args[0].dtype) + # Case 3: Typed pointer Var with PrimType element (typed TIR) + elif hasattr(data_ptr, "type_annotation") and data_ptr.type_annotation is not None: + try: + elem_ty = getattr(data_ptr.type_annotation, "element_type", None) + if elem_ty is not None and hasattr(elem_ty, "dtype"): + inferred = str(elem_ty.dtype) + except Exception: + inferred = None + if inferred is None: + raise ValueError("dtype must be provided when passing a pointer expression and cannot be inferred.") + dtype = inferred + if num_regs is None: + raise ValueError("num_regs must be provided when passing a pointer expression.") + + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.warpgroup_fence_operand"), + dtype, + data_ptr, + convert(offset), + convert(num_regs), + ) + ) + + +def wait_wgmma(id: int): + """Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete. + + Args: + id: int + The id of the WGMMA operation to wait for + + Returns: + tir.Call: A handle to the WGMMA wait operation + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id) + + +def barrier_wait(barrier_id: int | PrimExpr | tir.Call, parity: int | Var | None = None): + """Wait for a memory barrier to complete. + + Args: + barrier_id: Optional[int, PrimExpr] + The memory barrier to wait on + parity: Optional[int, Var] + The parity value to wait for + Returns: + tir.Call: A handle to the barrier wait operation + Current implementation is a sugar syntax for mbarrier_wait_parity, as we only support parity 0 and 1. + """ + return mbarrier_wait_parity(barrier_id, parity) + + +def barrier_arrive(barrier_id: int | PrimExpr | tir.Call): + """Arrive at a memory barrier. + + Args: + barrier_id: Optional[int, PrimExpr] + The memory barrier to arrive at + """ + return mbarrier_arrive(barrier_id) + + +def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): + """Perform a shuffle operation with XOR offset. + + Args: + value: Optional[int, PrimExpr] + The value to shuffle + offset: Optional[int, PrimExpr] + The offset for the shuffle operation + Returns: + tir.Call: A handle to the shuffle operation + """ + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_xor", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset) + + +def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): + """Perform a shuffle operation with down offset. + + Args: + value: Optional[int, PrimExpr] + The value to shuffle + """ + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_down", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset) + + +def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): + """Perform a shuffle operation with up offset. 
+ + Args: + value: Optional[int, PrimExpr] + The value to shuffle + """ + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_up", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset) + + +def sync_threads(barrier_id: int = None, arrive_count: int = None): + """Synchronize all threads in a block.""" + args = [] + if barrier_id is not None: + args.append(barrier_id) + if arrive_count is not None: + args.append(arrive_count) + return tir.call_intrin("int32", "tir.tvm_storage_sync", "shared", *args) + + +def sync_global(): + """Synchronize all threads in the entire grid.""" + tx, ty, tz = get_thread_bindings() + ex, ey, ez = get_block_extents() + print(tx, ty, tz, ex, ey, ez) + args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez] + return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args)) + + +def sync_grid(): + """Synchronize all threads in a grid.""" + return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) + + +def initialize_wgmma_descriptor( + descriptor: tir.Buffer, + start_address: PrimExpr, + layout_type_: int = 0, + leading_byte_offset: int = 0, + stride_byte_offset: int = 0, +) -> PrimExpr: + """Initialize a WGMMA/UTCMMA shared-memory descriptor.""" + + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") + + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) + + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.initialize_wgmma_descriptor"), + descriptor, + start_address, + layout_type_, + int(leading_byte_offset), + int(stride_byte_offset), + ) + ) + + +def initialize_tcgen05_descriptor( + descriptor: tir.Buffer, + start_address: PrimExpr, + leading_byte_offset: int, + stride_byte_offset: int, + base_offset: int = 0, + leading_is_absolute: bool = False, + swizzle_mode: int = 0, +) -> PrimExpr: + """Initialize a TCGEN05 shared-memory descriptor.""" + + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") + + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) + + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.initialize_tcgen05_descriptor"), + descriptor, + start_address, + int(leading_byte_offset), + int(stride_byte_offset), + int(base_offset), + tir.IntImm("int32", 1 if leading_is_absolute else 0), + int(swizzle_mode), + ) + ) + + +def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: + """ + Increase the offset of a memory descriptor. + + Parameters: + descriptor (PrimExpr): The memory descriptor to modify. + offset (PrimExpr): The offset value to increase. + + Returns: + PrimExpr: A handle representing the modified descriptor. 
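+
+    Example (illustrative sketch; ``desc`` stands in for a 1-element descriptor
+    buffer already set up with ``initialize_wgmma_descriptor`` or
+    ``initialize_tcgen05_descriptor``):
+        >>> T.increase_descriptor_offset(desc, 16)  # advance the encoded offset by 16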
+ """ + if not isinstance(descriptor, (BufferLoad, tir.Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") + + if isinstance(descriptor, tir.Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) + + return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, offset)) + + +def loop_break(): + """Break out of the innermost loop.""" + return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) + + +def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): + """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.""" + return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) + + +def tcgen05_mma_arrive(mbar_ptr): + """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer. + + Parameters + ---------- + mbar_ptr : PrimExpr + Pointer to the mbarrier object in shared memory (e.g., Barrier*). + """ + return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr) + + +def ptx_mma_sm70( + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta). + + This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape + with FP16 inputs and FP16/FP32 accumulation. + + Parameters + ---------- + + shape : str + The shape of mma fragment (e.g., "m16n16k4"). + + A_layout : str + The layout of multiplicand fragment A ("row" or "col"). + + B_layout : str + The layout of multiplicand fragment B ("row" or "col"). + + A_dtype : str + The data type of multiplicand fragment A (typically "fp16"). + + B_dtype : str + The data type of multiplicand fragment B (typically "fp16"). + + C_dtype : str + The data type of accumulator fragment C ("fp16" or "fp32"). + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment B. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. + + Examples + -------- + >>> T.ptx_mma_sm70( + ... "float16", + ... "m16n16k4", + ... "row", + ... "col", + ... "fp16", + ... "fp16", + ... "fp16", + ... A_local.data, + ... 0, + ... B_local.data, + ... 0, + ... C_local.data, + ... 0, + ... 
) + """ + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.ptx_mma_sm70"), + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) diff --git a/tilelang/original/tilelang/language/copy_op.py b/tilelang/original/tilelang/language/copy_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0b55c410c75832cd4ec153373585644c2e07b067 --- /dev/null +++ b/tilelang/original/tilelang/language/copy_op.py @@ -0,0 +1,142 @@ +"""Copy operations exposed on the TileLang language surface.""" + +from __future__ import annotations +from typing import Literal +from tilelang import language as T +from tilelang.utils.language import ( + to_buffer_region, + get_buffer_region_from_load, + legalize_pairwise_extents, +) +from tvm import ir, tir + + +def copy( + src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): + """Copy data between memory regions. + + Args: + src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region + dst (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Destination memory region + coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None. + + Raises: + TypeError: If copy extents cannot be deduced from arguments + + Returns: + tir.Call: A handle to the copy operation + + Range handling notes: + - Accepts `Buffer`/`BufferRegion`/`BufferLoad` on either side. Extents are + derived as follows: `Buffer -> shape`, `BufferRegion -> [r.extent]`, + `BufferLoad -> extents from its inferred/encoded region`. + - If both `src` and `dst` are scalar `BufferLoad` without region extents, + lowers to a direct store: `dst[...] = src`. + - If one side is missing extents, it is treated as all-ones with the other + side's rank to enable broadcasting. + - Extents are right-aligned and legalized via `legalize_pairwise_extents`: + per tail-dimension, equal keeps as-is, a `1` broadcasts to the other, + otherwise a conservative `tir.max` is used to remain safe for dynamic + shapes. + - The finalized extents are encoded with `tl.region` via `to_buffer_region` + and passed through to the backend; low-level loop construction and any + scope-specific decisions happen during lowering. 
+ """ + if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer): + ir.assert_structural_equal(src.shape, dst.shape) + + def get_extent(data): + if isinstance(data, tir.Var) and T.has_let_value(data): + data = T.get_let_value(data) + if isinstance(data, tir.Buffer): + return data.shape + elif isinstance(data, tir.BufferRegion): + return [x.extent for x in data.region] + elif isinstance(data, tir.BufferLoad): + region = get_buffer_region_from_load(data) + if region is None: + return None + return [x.extent for x in region.region] + else: + return None + + src_extent = get_extent(src) + dst_extent = get_extent(dst) + # Combine the nested if statements into a single if statement as suggested by SIM102 + if src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad): + # check if the case is like this: + # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes + # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] + return tir.BufferStore(dst.buffer, src, dst.indices) + + assert src_extent or dst_extent, "Can't deduce copy extents from args" + # Treat missing extent as length-matched ones to enable broadcasting. + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + + # Align and broadcast extents from the right (tail) side. + src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + # Use legalized extents for src and dst respectively. + src = to_buffer_region(src, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) + + if coalesced_width is None: + coalesced_width = -1 # PrimExpr can not be None + if eviction_policy is None: + eviction_policy = 0 + else: + eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) + + +def c2d_im2col( + img: tir.Buffer, + col: tir.Buffer, + nhw_step: tir.PrimExpr, + c_step: tir.PrimExpr, + kernel: int, + stride: int, + dilation: int, + pad: int, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): + """Perform im2col transformation for 2D convolution. 
+ + Args: + img (tir.Buffer): Input image buffer + col (tir.Buffer): Output column buffer + nhw_step (tir.PrimExpr): Step size for batch and spatial dimensions + c_step (tir.PrimExpr): Step size for channel dimension + kernel (int): Kernel size + stride (int): Stride of the convolution + dilation (int): Dilation rate + pad (int): Padding size + + Returns: + tir.Call: A handle to the im2col operation + """ + if eviction_policy is None: + eviction_policy = 0 + else: + eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] + img_region = to_buffer_region(img, access_type="r") + col_region = to_buffer_region(col, access_type="w") + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.c2d_im2col"), + img_region, + col_region, + nhw_step, + c_step, + kernel, + stride, + dilation, + pad, + eviction_policy, + ) diff --git a/tilelang/original/tilelang/language/customize.py b/tilelang/original/tilelang/language/customize.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4e754f7cd3a845ce153b54427805438b15045f --- /dev/null +++ b/tilelang/original/tilelang/language/customize.py @@ -0,0 +1,75 @@ +"""Some customized operations frequently used in tensor programming, exposed on the TileLang language surface.""" + +from __future__ import annotations +import tilelang.language as T +from tvm.tir import PrimExpr, Buffer, op +from tilelang.utils.language import bits_product, prim_expr_equal +from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 + + +def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr: + """Perform a 4-element dot product with accumulation (DP4A). + + Args: + A (Buffer): First input buffer + B (Buffer): Second input buffer + C (Buffer): Accumulation buffer + + Returns: + PrimExpr: Handle to the DP4A operation + """ + return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C)) + + +def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr: + """Clamps the input value dst between [min_val, max_val] + + Args: + dst: Input value to be clamped + min_val: Minimum value + max_val: Maximum value + + Returns: + Value clamped to the specified range + """ + dst = T.max(dst, min_val) # Ensure value is not less than minimum + dst = T.min(dst, max_val) # Ensure value is not greater than maximum + return dst + + +def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: + """Reshapes the input buffer to the specified shape. + + Args: + src (Buffer): Input buffer to be reshaped + shape (List[PrimExpr]): New shape for the buffer + + Returns: + Buffer: A new buffer view with the specified shape + """ + assert prim_expr_equal(bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)), ( + f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" + ) + return T.Tensor(shape, src.dtype, src.data) + + +def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer: + """Return a Tensor view of the input buffer with an optional new shape and dtype. + + If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy). 
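+
+    Example (illustrative sketch; ``acc`` stands in for a [64, 64] "float32" fragment):
+        >>> flat = T.view(acc, [64 * 64])       # flatten, same dtype
+        >>> bits = T.view(acc, dtype="uint32")  # reinterpret elements, same total bit count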
+ """ + if shape is None: + shape = src.shape + if dtype is None: + dtype = src.dtype + assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." + return T.Tensor(shape, dtype, src.data) + + +def loop_break(): + """Break out of the current loop. + + Returns: + tir.Call: A call to the `tl.loop_break` intrinsic. + """ + return T.call_intrin("handle", op.Op.get("tl.loop_break")) diff --git a/tilelang/original/tilelang/language/experimental/__init__.py b/tilelang/original/tilelang/language/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tilelang/original/tilelang/language/experimental/gemm_sp.py b/tilelang/original/tilelang/language/experimental/gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..1eaac6805e571e16f02d72d7334c0c2d6111e47d --- /dev/null +++ b/tilelang/original/tilelang/language/experimental/gemm_sp.py @@ -0,0 +1,223 @@ +"""The language interface for tl programs.""" + +from __future__ import annotations +from tilelang.tileop.base import GemmWarpPolicy +import tilelang.language as T +from tvm import tir +from tilelang.utils.language import ( + to_buffer_region, + retrieve_shape, + retrieve_stride, + retrieve_offset, + prim_expr_equal, +) +from tilelang.language.utils import ( + buffer_region_to_tile_region, +) + + +def gemm_sp( + A_sparse: tir.Buffer | tir.Var, + E: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +): + """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation. + + This function computes C = A @ B where A and B can optionally be transposed. + The operation supports various warp policies and accumulation modes. + + Args: + A_sparse (Union[tir.Buffer, tir.Var]): First input matrix dense values + E (Union[tir.Buffer, tir.Var]): First input matrix sparse metadata + B (Union[tir.Buffer, tir.Var]): Second input matrix + C (Union[tir.Buffer, tir.Var]): Output matrix for results + transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. + transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. + policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. + clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. + k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. + wg_wait (int, optional): Warp group wait count. Defaults to 0. + + Returns: + tir.Call: A handle to the GEMM operation + + Raises: + AssertionError: If the K dimensions of matrices A and B don't match + """ + + def legalize_arguments(arg: tir.Buffer | tir.Var): + """Convert let-bound variables to their corresponding buffers. 
+ + Args: + arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + + Returns: + Union[tir.Buffer, tir.Var]: The legalized argument + """ + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A_sparse = legalize_arguments(A_sparse) + B = legalize_arguments(B) + C = legalize_arguments(C) + M = C.shape[0] + N = C.shape[1] + K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1] + K_B = B.shape[1] if transpose_B else B.shape[0] + assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}" + # Build tl.region descriptors for operands + A_arg = to_buffer_region(A_sparse, access_type="r") + E_arg = to_buffer_region(E, access_type="r") + B_arg = to_buffer_region(B, access_type="r") + C_arg = to_buffer_region(C, access_type="rw") + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.gemm_sp"), + A_arg, + E_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + K_B, + policy, + clear_accum, + k_pack, + wg_wait, + ) + + +# experimental currently, for fast compilation +def gemm_sp_v2( + A_sparse: tir.Buffer | tir.Var, + E: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + transpose_E: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +): + """Perform a General Matrix Multiplication (GEMM) operation. + + This function computes C = A @ B where A and B can optionally be transposed. + The operation supports various warp policies and accumulation modes. + + Args: + A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements + E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E + B (Union[tir.Buffer, tir.Var]): Second input matrix + C (Union[tir.Buffer, tir.Var]): Output matrix for results + transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. + transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. + policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. + clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. + k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. + wg_wait (int, optional): Warp group wait count. Defaults to 0. + + Returns: + tir.Call: A handle to the GEMM operation + + Raises: + AssertionError: If the K dimensions of matrices A and B don't match + """ + + def legalize_arguments(arg: tir.Buffer | tir.Var): + """Convert let-bound variables to their corresponding buffers. 
+ + Args: + arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + + Returns: + Union[tir.Buffer, tir.Var]: The legalized argument + """ + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A_sparse = legalize_arguments(A_sparse) + E = legalize_arguments(E) + B = legalize_arguments(B) + C = legalize_arguments(C) + + A_region = to_buffer_region(A_sparse) + E_region = to_buffer_region(E) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + + A_shape = retrieve_shape(A_sparse) + E_shape = retrieve_shape(E) # nolint: F841 + B_shape = retrieve_shape(B) + C_shape = retrieve_shape(C) + + A_stride = retrieve_stride(A_sparse) + B_stride = retrieve_stride(B) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + if len(A_shape) > 2: + for i in range(len(A_shape) - 2): + assert A_shape[i] == 1, ( + "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + if len(B_shape) > 2: + for i in range(len(B_shape) - 2): + assert B_shape[i] == 1, ( + "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + + M, N = C_shape + K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" + + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + A_offset = retrieve_offset(A_sparse) + B_offset = retrieve_offset(B) + assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" + assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.gemm_sp_py"), + A_arg, + E_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + transpose_E, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + ) diff --git a/tilelang/original/tilelang/language/fastmath.py b/tilelang/original/tilelang/language/fastmath.py new file mode 100644 index 0000000000000000000000000000000000000000..c77fad34c5d96953ad28b422c279fecee704d8b5 --- /dev/null +++ b/tilelang/original/tilelang/language/fastmath.py @@ -0,0 +1,151 @@ +"""Fast math operations exposed on the TileLang language surface.""" + +from tvm import tir + + +def __log(x): + """Calculate log(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log"), x) + + +def __log2(x): + """Calculate log2(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. 
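A plain-Python illustration of the shape relationship enforced by `gemm_sp` and `gemm_sp_v2`: the sparse operand stores only the kept values, so its K extent is half of B's K extent, with the dropped positions described by the metadata buffer `E` (whose layout is not specified here). The tile sizes below are hypothetical.

```python
# Hypothetical tile sizes, for illustration only.
M, N, K = 128, 128, 64

A_sparse_shape = (M, K // 2)   # kept values of the sparse operand (transpose_A=False)
B_shape = (K, N)               # dense operand (transpose_B=False)
C_shape = (M, N)

# Mirrors the checks above: K_A * 2 == K_B (v1) / prim_expr_equal(2 * K_A, K_B) (v2).
assert A_sparse_shape[1] * 2 == B_shape[0]
```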
+ """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log2"), x) + + +def __log10(x): + """Calculate log10(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log10"), x) + + +def __tan(x): + """Calculate tan(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__tan"), x) + + +def __cos(x): + """Calculate cos(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__cos"), x) + + +def __sin(x): + """Calculate sin(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__sin"), x) + + +def __exp10(x): + """Calculate 10**x with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp10"), x) + + +def __exp(x): + """Calculate 2**x with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp"), x) + + +__all__ = [ + "__log", # noqa: F401 + "__log2", # noqa: F401 + "__log10", # noqa: F401 + "__tan", # noqa: F401 + "__cos", # noqa: F401 + "__sin", # noqa: F401 + "__exp10", # noqa: F401 + "__exp", # noqa: F401 +] diff --git a/tilelang/original/tilelang/language/fill_op.py b/tilelang/original/tilelang/language/fill_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a093a84599009038c10a5a23ed530c743ad066e1 --- /dev/null +++ b/tilelang/original/tilelang/language/fill_op.py @@ -0,0 +1,62 @@ +"""Fill operations exposed on the TileLang language surface.""" + +from __future__ import annotations +from tvm import tir +from tilelang.language import has_let_value, get_let_value +from tilelang.utils.language import get_buffer_region_from_load, to_buffer_region + + +def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr): + """Fill a buffer or buffer region with a specified value. 
+ + Args: + buffer: Either a TVM buffer or buffer region to be filled + value: The value to fill the buffer with + + Returns: + A TVM intrinsic call that performs the fill operation + """ + # Normalize Var with let value to its underlying object + if isinstance(buffer, tir.Var) and has_let_value(buffer): + buffer = get_let_value(buffer) + + # Build tl.region as argument + if isinstance(buffer, tir.Buffer): + extents = list(buffer.shape) + elif isinstance(buffer, tir.BufferRegion): + extents = [r.extent for r in buffer.region] + elif isinstance(buffer, tir.BufferLoad): + region = get_buffer_region_from_load(buffer) + if region is not None: + extents = [r.extent for r in region.region] + else: + extents = [tir.IntImm("int32", 1) for _ in buffer.indices] + else: + extents = [] + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value) + + +def clear(buffer: tir.Buffer | tir.Var): + """Clear a buffer by filling it with zeros. + + Args: + buffer: Either a TVM buffer or a variable that contains a buffer region + + Returns: + A fill operation that sets the buffer contents to zero + + Raises: + ValueError: If the buffer variable contains an invalid buffer region + """ + if isinstance(buffer, tir.Var) and has_let_value(buffer): + buffer_region = get_let_value(buffer) # Get the actual buffer region from variable + if isinstance(buffer_region, tir.BufferRegion): + return fill(buffer_region, 0) + elif isinstance(buffer_region, tir.BufferLoad): + region = get_buffer_region_from_load(buffer_region) + if region is None: + raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") + return fill(region, 0) + else: + raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") + return fill(buffer, 0) diff --git a/tilelang/original/tilelang/language/frame.py b/tilelang/original/tilelang/language/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..7e60f46ee98da3e3283744b39a575068e7dfed09 --- /dev/null +++ b/tilelang/original/tilelang/language/frame.py @@ -0,0 +1,209 @@ +"""Override the LetFrame to print a message when entering the frame.""" + +from __future__ import annotations +from tvm.ffi import register_object as _register_object +from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion +from tvm.ir import Range +from tvm import DataType +from tvm.script.ir_builder.tir.frame import TIRFrame +from collections import deque +import threading + + +class FrameStack: + """A stack-like container for managing TIR frame objects and their variable bindings. + + This class implements a stack data structure using a deque and maintains a mapping + of variables to their values. It provides methods for stack operations and variable + value lookups. + """ + + def __init__(self): + """Initialize an empty frame stack and variable mapping.""" + self._stack = deque() + self._var_value_map = {} + + def push(self, item): + """Push an item onto the stack and update variable mapping if applicable. + + Args: + item: The frame object to push onto the stack + """ + self._stack.append(item) + if hasattr(item, "var") and hasattr(item, "value"): + self._var_value_map[item.var] = item.value + + def pop(self): + """Remove and return the top item from the stack. 
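A minimal sketch of `fill` and `clear` as defined above; whole buffers, buffer regions, and let-bound region variables are all accepted.

```python
import tilelang.language as T


@T.prim_func
def fill_demo(A: T.Tensor((128, 128), "float32")):
    with T.Kernel(1, threads=128):
        T.clear(A)                    # same as T.fill(A, 0)
        T.fill(A[0:64, 0:64], 1.0)    # a BufferRegion argument fills only that sub-region
```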
+ + Returns: + The top frame object from the stack + + Raises: + IndexError: If the stack is empty + """ + if self._stack: + item = self._stack.pop() + if hasattr(item, "var"): + self._var_value_map.pop(item.var, None) + return item + raise IndexError(f"{self.__class__.__name__} is empty") + + def get_value(self, var): + """Retrieve the value associated with a variable. + + Args: + var: The variable to look up + + Returns: + The value associated with the variable, or None if not found + """ + return self._var_value_map.get(var) + + def has_value(self, var): + """Check if a variable has an associated value. + + Args: + var: The variable to check + + Returns: + bool: True if the variable has an associated value, False otherwise + """ + return var in self._var_value_map + + def top(self): + """Return the top item of the stack without removing it. + + Returns: + The top frame object from the stack + + Raises: + IndexError: If the stack is empty + """ + if self._stack: + return self._stack[-1] + raise IndexError(f"{self.__class__.__name__} is empty") + + def __len__(self): + """Returns the number of items in the stack.""" + return len(self._stack) + + def __bool__(self): + """ + Allows truthy checks on the stack object itself, + e.g., 'if stack: ...' + """ + return bool(self._stack) + + +# Use thread local to store the stack +# This is to avoid the cross-thread interference +_local_let = threading.local() + + +def _get_let_stack() -> FrameStack: + if not hasattr(_local_let, "let_frame_stack"): + _local_let.let_frame_stack = FrameStack() + return _local_let.let_frame_stack + + +@_register_object("script.ir_builder.tir.LetFrame") +class LetFrame(TIRFrame): + """A TIR frame for let bindings that manages variable scope and value tracking. + + This frame type extends TIRFrame to provide variable binding functionality and + maintains a global stack of active bindings. + """ + + def __enter__(self) -> Var: + """Enter the let frame scope and process buffer loads. + + Returns: + Var: The variable bound in this frame + """ + super().__enter__() + if isinstance(self.value, BufferLoad): + indices = self.value.indices + is_block_load = False + for index in indices[:-1]: + if DataType(index.dtype).lanes > 1: + is_block_load = True + break + if is_block_load: + self.value = BufferRegion(self.value.buffer, [Range(x.base, x.lanes) for x in indices]) + + _get_let_stack().push(self) + return self.var + + def __exit__(self, ptype, value, trace): + """Exit the let frame scope and clean up the stack. + + Args: + ptype: Exception type if an exception occurred + value: Exception value if an exception occurred + trace: Exception traceback if an exception occurred + """ + stack = _get_let_stack() + if stack.top() is self: + stack.pop() + super().__exit__(ptype, value, trace) + + @classmethod + def Current(cls) -> LetFrame: + """Get the current (topmost) let frame. + + Returns: + LetFrame: The current let frame + + Raises: + IndexError: If there are no active let frames + """ + return _get_let_stack().top() + + @staticmethod + def get_value(var: Var): + """Get the value bound to a variable in any active frame. + + Args: + var (Var): The variable to look up + + Returns: + The value bound to the variable, or None if not found + """ + return _get_let_stack().get_value(var) + + @staticmethod + def has_value(var: Var) -> bool: + """Check if a variable has a binding in any active frame. 
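The let-frame bookkeeping above can be exercised directly; the sketch below assumes the module is importable as `tilelang.language.frame` and uses a stand-in object with `var`/`value` attributes.

```python
from types import SimpleNamespace
from tilelang.language.frame import FrameStack  # assumed import path for the module above

stack = FrameStack()
binding = SimpleNamespace(var="x", value=42)   # anything exposing .var/.value is tracked
stack.push(binding)
assert stack.has_value("x") and stack.get_value("x") == 42
stack.pop()                                    # popping removes the mapping again
assert not stack.has_value("x")
```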
+ + Args: + var (Var): The variable to check + + Returns: + bool: True if the variable has a binding, False otherwise + """ + return _get_let_stack().has_value(var) + + +def has_let_value(var: Var) -> bool: + """Check if a variable has a binding in the current let frame stack. + + Args: + var (Var): The variable to check + + Returns: + bool: True if the variable has a binding, False otherwise + """ + return _get_let_stack().has_value(var) + + +def get_let_value(var: Var) -> PrimExpr | None: + """Get the value bound to a variable in the current let frame stack. + + Args: + var (Var): The variable to look up + + Returns: + Optional[PrimExpr]: The bound value if found, None otherwise + """ + return _get_let_stack().get_value(var) diff --git a/tilelang/original/tilelang/language/gemm_op.py b/tilelang/original/tilelang/language/gemm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bda2b9423b9cffcc17106fe50e55c664f60997 --- /dev/null +++ b/tilelang/original/tilelang/language/gemm_op.py @@ -0,0 +1,187 @@ +"""GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface.""" + +from __future__ import annotations +from tilelang.tileop.base import GemmWarpPolicy +import tilelang.language as T +from tvm import tir +from tilelang.utils.language import ( + to_buffer_region, + retrieve_shape, + retrieve_stride, + retrieve_offset, + prim_expr_equal, +) +from tilelang.language.utils import ( + buffer_region_to_tile_region, +) +from tilelang.env import env as _env + + +def _gemm_impl( + op_key: str, + A: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, + mbar: tir.Buffer | None = None, +): + """Shared GEMM implementation. + + Returns a call_intrin handle for the given op key. + """ + + def legalize_arguments(arg: tir.Buffer | tir.Var): + """Convert let-bound variables to their corresponding buffers. 
+ + Args: + arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + + Returns: + Union[tir.Buffer, tir.Var]: The legalized argument + """ + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A = legalize_arguments(A) + B = legalize_arguments(B) + C = legalize_arguments(C) + mbar = legalize_arguments(mbar) if mbar is not None else None + + # Normalize A/B/C to BufferRegion for shape/stride/offset analysis + A_region = to_buffer_region(A) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + + A_shape = retrieve_shape(A_region) + B_shape = retrieve_shape(B_region) + C_shape = retrieve_shape(C_region) + + A_stride = retrieve_stride(A_region) + B_stride = retrieve_stride(B_region) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + if len(A_shape) > 2: + for i in range(len(A_shape) - 2): + assert A_shape[i] == 1, ( + "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + if len(B_shape) > 2: + for i in range(len(B_shape) - 2): + assert B_shape[i] == 1, ( + "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + + M, N = C_shape + K = A_shape[-2] if transpose_A else A_shape[-1] + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal(K, K_B), f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}" + + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + A_offset = retrieve_offset(A_region) + B_offset = retrieve_offset(B_region) + assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" + assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, T.uint32) + C_coords = [r.min for r in C_region.region] + # Convert BufferRegion to tl.region calls for arguments + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + return tir.call_intrin( + "handle", + tir.op.Op.get(op_key), + A_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + mbar, + C_coords[0], + C_coords[1], + ) + + +# Public wrappers +def gemm_v1( + A: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, + mbar: tir.Buffer | None = None, +): + """GEMM v1: use op tl.gemm.""" + return _gemm_impl( + "tl.tileop.gemm", + A, + B, + C, + transpose_A, + transpose_B, + policy, + clear_accum, + k_pack, + wg_wait, + mbar, + ) + + +# experimental currently, for fast compilation +def gemm_v2( + A: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, + mbar: tir.Buffer | None = None, +): + """GEMM 
v2: use op tl.gemm_py.""" + return _gemm_impl( + "tl.tileop.gemm_py", + A, + B, + C, + transpose_A, + transpose_B, + policy, + clear_accum, + k_pack, + wg_wait, + mbar, + ) + + +# Default to v2; allow forcing v1 via environment variable +gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2 diff --git a/tilelang/original/tilelang/language/kernel.py b/tilelang/original/tilelang/language/kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..73f7ed949079a381c1c929dc0860eb133bbad2c5 --- /dev/null +++ b/tilelang/original/tilelang/language/kernel.py @@ -0,0 +1,350 @@ +"""Kernel launching language interface in TileLang.""" + +from __future__ import annotations +from collections import deque +from tvm import tir +from tvm.tir import Var +from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame +from tvm.ffi import register_object +from tilelang import _ffi_api +import threading + +# Ensure single-dimension kernel bindings can be unpacked like iterables. +# especially for issue https://github.com/tile-ai/tilelang/issues/830 +if not hasattr(Var, "__iter__"): + + def _var_iter(self): + yield self + + Var.__iter__ = _var_iter # type: ignore[attr-defined] + +if not hasattr(Var, "__len__"): + Var.__len__ = lambda self: 1 # type: ignore[attr-defined] + + +class FrameStack: + """ + A simple stack-like wrapper around a deque that provides + push, pop, and top methods for convenience. + """ + + def __init__(self): + self._stack = deque() + + def push(self, item): + """Pushes an item onto the top of the stack.""" + self._stack.append(item) + + def pop(self): + """ + Pops and returns the top of the stack, or returns None + if the stack is empty. + """ + if self._stack: + return self._stack.pop() + raise IndexError(f"{self.__class__.__name__} is empty") + + def top(self): + """ + Returns the item on the top of the stack without removing it, + or None if the stack is empty. + """ + if self._stack: + return self._stack[-1] + raise IndexError(f"{self.__class__.__name__} is empty") + + def size(self): + """Returns the number of items in the stack.""" + return len(self._stack) + + def __len__(self): + """Returns the number of items in the stack.""" + return len(self._stack) + + def __bool__(self): + """ + Allows truthy checks on the stack object itself, + e.g., 'if stack: ...' + """ + return bool(self._stack) + + +# Use thread local to store the stack +# This is to avoid the cross-thread interference +_local = threading.local() + + +def _get_current_stack() -> FrameStack: + if not hasattr(_local, "kernel_launch_frame_stack"): + _local.kernel_launch_frame_stack = FrameStack() + return _local.kernel_launch_frame_stack + + +def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]: + """ + Return a bare Var when we only have a single binding so that users may write either + `with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`. + Otherwise, keep the list semantics for multi-dimensional launches. + """ + if len(bindings) == 1: + return bindings[0] + return bindings + + +@register_object("tl.KernelLaunchFrame") +class KernelLaunchFrame(TIRFrame): + """ + KernelLaunchFrame is a custom TIRFrame that manages block/thread indices + and handles the entry and exit of the kernel launch scope. + """ + + def __enter__(self) -> Var | list[Var]: + """ + Enters the KernelLaunchFrame scope and pushes this frame onto the stack. + Returns one Var if we detect exactly 5 frames (meaning there is a single + block dimension), or a list of Vars otherwise. 
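The sketch below shows how `T.gemm` composes with `T.copy`, `T.clear`, and `T.Pipelined` in a block-tiled matmul; `T.alloc_shared` and `T.alloc_fragment` are assumed from the wider TileLang namespace, and `T.gemm` itself resolves to `gemm_v1` or `gemm_v2` depending on the environment flag above.

```python
import tilelang.language as T

M = N = K = 1024
block_M, block_N, block_K = 128, 128, 32


@T.prim_func
def matmul(A: T.Tensor((M, K), "float16"),
           B: T.Tensor((K, N), "float16"),
           C: T.Tensor((M, N), "float16")):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_s = T.alloc_shared((block_M, block_K), "float16")    # assumed allocator
        B_s = T.alloc_shared((block_K, block_N), "float16")    # assumed allocator
        C_f = T.alloc_fragment((block_M, block_N), "float32")  # assumed accumulator tile
        T.clear(C_f)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[by * block_M, ko * block_K], A_s)
            T.copy(B[ko * block_K, bx * block_N], B_s)
            # K is taken from the operand shapes; transpose_A/transpose_B default to False.
            T.gemm(A_s, B_s, C_f)
        T.copy(C_f, C[by * block_M, bx * block_N])
```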
+ """ + super().__enter__() + _get_current_stack().push(self) + + last_block_frame = self.frames[-1] + assert isinstance(last_block_frame, BlockFrame), f"Last frame must be a block frame, got {last_block_frame}" + + maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False) + + if maybe_cpu: + # CPU kernel frame, return a list of for frame items. + return _normalize_bindings([frame.vars[0] for frame in self.frames[0:-1]]) + else: + # Otherwise, return a list of iter_var.var objects (excluding the last 4 frames). + # As 4 frames for threadIdx.x, threadIdx.y, threadIdx.z and block frame with attributes + return _normalize_bindings([frame.iter_var.var for frame in self.frames[0:-4]]) + + def __exit__(self, ptype, value, trace): + """ + Exits the KernelLaunchFrame scope and pops this frame from the stack, + but only if it's indeed the topmost frame. + """ + stack = _get_current_stack() + if stack.top() is self: + stack.pop() + super().__exit__(ptype, value, trace) + + @classmethod + def Current(cls) -> KernelLaunchFrame | None: + """ + Returns the topmost (current) KernelLaunchFrame from the stack if it exists, + or None if the stack is empty. + """ + stack = _get_current_stack() + return stack.top() if stack else None + + def get_block_extent(self, dim: int) -> int: + """ + Returns the block extent for the given dimension. + dim=0 corresponds to blockIdx.x, dim=1 to blockIdx.y, and dim=2 to blockIdx.z. + """ + iter_var = self.frames[dim].iter_var + return int(iter_var.dom.extent) + + def get_block_extents(self) -> list[int]: + """ + Returns the block extents for all three dimensions. + """ + return [self.get_block_extent(dim) for dim in range(3)] + + def get_thread_extent(self, dim: int) -> int: + """ + Returns the thread extent for the given dimension. + dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z. + """ + iter_var = self.frames[-4 + dim].iter_var + return int(iter_var.dom.extent) + + def get_thread_extents(self) -> list[int]: + """ + Returns the thread extents for all three dimensions. + """ + return [self.get_thread_extent(dim) for dim in range(3)] + + def get_thread_binding(self, dim: int = 0) -> Var: + """ + Returns the thread binding for the given dimension. + dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z. + """ + return self.frames[-4 + dim].iter_var.var + + def get_thread_bindings(self) -> list[Var]: + """ + Returns the thread binding for the given dimension. + dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z. + """ + return [frame.iter_var.var for frame in self.frames[-4:-1]] + + def get_num_threads(self) -> int: + """ + Returns the thread indices from the topmost frame. + """ + num_threads: int = 1 + for thread_dim in range(3): + num_threads *= self.get_thread_extent(thread_dim) + return num_threads + + def get_block_binding(self, dim: int = 0) -> Var: + """ + Returns the block binding for the given dimension. + dim=0 corresponds to blockIdx.x, dim=1 to blockIdx.y, and dim=2 to blockIdx.z. + """ + return self.frames[dim].iter_var.var + + def get_block_bindings(self) -> list[Var]: + """ + Returns all three block bindings. + """ + return [frame.iter_var.var for frame in self.frames[0:-4]] + + @property + def blocks(self) -> list[Var]: + """ + Returns the block indices from the topmost frame. 
+ """ + return [frame.iter_var.var for frame in self.frames[0:-4]] + + @property + def threads(self) -> list[Var]: + """ + Returns the thread indices from the topmost frame. + """ + return [frame.iter_var.var for frame in self.frames[-4:]] + + @property + def num_threads(self) -> int: + """ + Returns the total number of threads. + """ + return self.get_num_threads() + + +def Kernel( + *blocks: list[tir.PrimExpr], + threads: int | list[int] | tuple | None = None, + is_cpu: bool = False, + prelude: str | None = None, +): + """Tools to quickly construct a GPU kernel launch frame. + + Parameters + ---------- + blocks : List[int] + A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z) + threads : int + A integer representing blockDim.x + Or a list of integers representing blockDim.(x|y|z) + if the value is -1, we skip the threadIdx.x binding. + is_cpu : bool + Whether the kernel is running on CPU. + Thus we will not bind threadIdx.x, threadIdx.y, threadIdx.z. + and blockIdx.x, blockIdx.y, blockIdx.z. + prelude : str + The import c code of the kernel, + will be injected before the generated kernel code. + + Returns + ------- + res : Tuple[frame.LaunchThreadFrame] + The result LaunchThreadFrame. + + Examples + -------- + Create a 1-D CUDA kernel launch and unpack the single block index: + + .. code-block:: python + + with T.Kernel(T.ceildiv(N, 128), threads=128) as bx: + # bx is the blockIdx.x binding (also iterable as (bx,)) + ... + + Launch a 2-D grid while requesting two thread dimensions: + + .. code-block:: python + + with T.Kernel(grid_x, grid_y, threads=(64, 2)) as (bx, by): + tx, ty = T.get_thread_bindings() + ... + + Emit a CPU kernel where thread bindings are skipped: + + .. code-block:: python + + with T.Kernel(loop_extent, is_cpu=True) as (i,): + ... 
+ """ + attrs: dict = {} + + if not is_cpu and threads is None: + threads = 128 # default thread number + + if isinstance(threads, int): + threads = [threads, 1, 1] + elif isinstance(threads, list): + threads = threads + [1] * (3 - len(threads)) + elif isinstance(threads, tuple): + threads = list(threads) + [1] * (3 - len(threads)) + else: + assert is_cpu, "threads must be an integer or a list of integers" + + if is_cpu: + attrs["tilelang.is_cpu_kernel_frame"] = True + + if prelude is not None: + attrs["pragma_import_c"] = prelude + + return _ffi_api.KernelLaunch(blocks, threads, attrs) + + +def get_thread_binding(dim: int = 0) -> Var: + """Returns the thread binding for the given dimension.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_thread_binding(dim) + + +def get_thread_bindings() -> list[Var]: + """Returns all three thread bindings.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_thread_bindings() + + +def get_block_binding(dim: int = 0) -> Var: + """Returns the block binding for the given dimension.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_block_binding(dim) + + +def get_block_bindings() -> list[Var]: + """Returns all three block bindings.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_block_bindings() + + +def get_thread_extent(dim: int = 0) -> int: + """Returns the thread extent for the given dimension.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_thread_extent(dim) + + +def get_thread_extents() -> list[int]: + """Returns all three thread extents.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_thread_extents() + + +def get_block_extent(dim: int = 0) -> int: + """Returns the block extent for the given dimension.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_block_extent(dim) + + +def get_block_extents() -> list[int]: + """Returns all three block extents.""" + assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" + return KernelLaunchFrame.Current().get_block_extents() diff --git a/tilelang/original/tilelang/language/logical.py b/tilelang/original/tilelang/language/logical.py new file mode 100644 index 0000000000000000000000000000000000000000..66f0a2e2b56a2a60e82555492978b0f9b101c117 --- /dev/null +++ b/tilelang/original/tilelang/language/logical.py @@ -0,0 +1,76 @@ +"""Logical operations exposed on the TileLang language surface.""" + +from __future__ import annotations + +from tilelang import language as T +from tvm.tir import Buffer, BufferRegion, BufferLoad +from tvm import tir +from tilelang.utils.language import get_buffer_elems + + +def any_of(buffer: T.Tensor | BufferRegion): + """Check if any element in the buffer is true. 
+ + Args: + buffer: Either a TVM buffer or buffer region to be checked + + Returns: + A TVM intrinsic call that performs the any operation + """ + return_type: str = "bool" + if isinstance(buffer, Buffer): + elems = get_buffer_elems(buffer) + return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer), elems) + elif isinstance(buffer, BufferRegion): + buffer, region = buffer.buffer, buffer.region + new_region = [] + extent = 1 + for i, r in enumerate(region): + extent = r.extent + if extent == 1: + new_region.append(r.min) + else: + # check the idx is the last dimension + if i != len(region) - 1: + raise ValueError( + "Only support the last dimension to be for T.any currently, please contact us if you need this feature" + ) + new_region.append(r.min) + buffer_load = BufferLoad(buffer, new_region) + return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), extent) + else: + raise ValueError(f"Invalid buffer type: {type(buffer)}") + + +def all_of(buffer: T.Tensor | BufferRegion): + """Check if all elements in the buffer are true. + + Args: + buffer: Either a TVM buffer or buffer region to be checked + + Returns: + A TVM intrinsic call that performs the any operation + """ + return_type: str = "bool" + if isinstance(buffer, Buffer): + elems = get_buffer_elems(buffer) + return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer), elems) + elif isinstance(buffer, BufferRegion): + buffer, region = buffer.buffer, buffer.region + new_region = [] + extent = 1 + for i, r in enumerate(region): + extent = r.extent + if extent == 1: + new_region.append(r.min) + else: + # check the idx is the last dimension + if i != len(region) - 1: + raise ValueError( + "Only support the last dimension to be for T.any currently, please contact us if you need this feature" + ) + new_region.append(r.min) + buffer_load = BufferLoad(buffer, new_region) + return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), extent) + else: + raise ValueError(f"Invalid buffer type: {type(buffer)}") diff --git a/tilelang/original/tilelang/language/loop.py b/tilelang/original/tilelang/language/loop.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbd4e9f8644f179f06774db713e3dc65494267a --- /dev/null +++ b/tilelang/original/tilelang/language/loop.py @@ -0,0 +1,186 @@ +"""Loop related language interfaces in TileLang.""" + +from __future__ import annotations +from typing import Any +from tvm import tir +from tvm.tir import IntImm +import tvm.script.ir_builder.tir as tb_tir +from .v2.builder import SerialForWithStep, UnrollForWithStep +from tilelang import _ffi_api +from tvm.script.ir_builder.tir import frame + + +def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): + """Tools to construct nested parallel for loop. + This can be used to create element-wise tensor expression. + + Parameters + ---------- + extents : PrimExpr + The extents of the iteration. + + coalesced_width : Optional[int] + The coalesced width of the parallel loop. + + Returns + ------- + res : frame.ForFrame + The ForFrame. 
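A short sketch of the reduction predicates above; only the last dimension of a region may span more than one element, and the helpers are assumed to be re-exported on `T`.

```python
import tilelang.language as T


@T.prim_func
def predicate_demo(P: T.Tensor((128,), "bool"), Out: T.Tensor((2,), "bool")):
    with T.Kernel(1, threads=128):
        Out[0] = T.any_of(P)         # true if any element of P is true
        Out[1] = T.all_of(P[0:64])   # region form: only the last dimension may have extent > 1
```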
+ """ + annotations: dict[str, Any] = {} + if coalesced_width is not None: + annotations.update({"coalesced_width": coalesced_width}) + return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Persistent( + domain: list[tir.PrimExpr], + wave_size: tir.PrimExpr, + index: tir.PrimExpr, + group_size: tir.PrimExpr | None = 8, +): + """Tools to construct persistent for loop. + + Parameters + ---------- + domain : List[tir.PrimExpr] + The list of dominators. + wave_size : int + The wave size. + index : int + The tile index in one wave. + group_size : tir.PrimExpr + The group size. + """ + return _ffi_api.Persistent(domain, wave_size, index, group_size) + + +def Pipelined( + start: tir.PrimExpr, + stop: tir.PrimExpr = None, + num_stages: int = 0, + order: list[int] | None = None, + stage: list[int] | None = None, + sync: list[list[int]] | None = None, + group: list[list[int]] | None = None, +): + """Tools to construct pipelined for loop. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + stop : PrimExpr + The maximum value of iteration. + num_stages : int + The max number of buffer used between pipeline producers and consumers. + if num_stages is 0, pipeline will not be enabled. + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 + if order is None: + order = [] + if stage is None: + stage = [] + if sync is None: + sync = [] + if group is None: + group = [] + # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) + + +def serial( + start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None +) -> frame.ForFrame: + step_is_one = False + step_is_one |= isinstance(step, int) and step == 1 + step_is_one |= isinstance(step, IntImm) and step.value == 1 + if step is None or step_is_one: + return tb_tir.serial(start, stop, annotations=annotations) + else: + if stop is None: + stop = start + start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 + return SerialForWithStep(start, stop, step, annotations=annotations) + + +def unroll( + start: tir.PrimExpr, + stop: tir.PrimExpr | None = None, + step: tir.PrimExpr | None = None, + *, + explicit: bool = False, + unroll_factor: int | None = None, + annotations: dict[str, Any] | None = None, +) -> frame.ForFrame: + """The unrolled For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + step : PrimExpr + The step size of the iteration. + + explicit : bool + Whether to explicitly unroll the loop. + + unroll_factor : int + The unroll factor of the loop. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. 
+ """ + + step_is_one = False + if stop is None: + stop = start + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 + + # Ensure annotations has {"pragma_unroll_explicit": True} by default + if annotations is None: + annotations = {"pragma_unroll_explicit": explicit} + else: + # Add "pragma_unroll_explicit": True if not already present + annotations = dict(annotations) + annotations.setdefault("pragma_unroll_explicit", explicit) + + if unroll_factor is not None: + # check pragma_unroll_explicit must be False + if annotations.get("pragma_unroll_explicit", True): + raise ValueError("pragma_unroll_explicit must be True when unroll_factor is not None") + annotations.update({"pragma_unroll_factor": unroll_factor}) + + if step is None or step_is_one: + return tb_tir.unroll(start, stop, annotations=annotations) + else: + return UnrollForWithStep(start, stop, step, annotations=annotations) + + +# "Serial" and "Unroll" are aliases of "T.serial" and "T.unroll". We use uppercase to emphasize that they are tile-level loops. + + +def Serial(*args, **kwargs): + return serial(*args, **kwargs) + + +def Unroll(*args, **kwargs): + return unroll(*args, **kwargs) diff --git a/tilelang/original/tilelang/language/math_intrinsics.py b/tilelang/original/tilelang/language/math_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfb617e5521ff3e9f5d5c244764f65c67086cb7 --- /dev/null +++ b/tilelang/original/tilelang/language/math_intrinsics.py @@ -0,0 +1,352 @@ +"""Common math intrinsics exposed on the TileLang language surface.""" + +from tvm import tir + + +def _validate_rounding_mode(rounding_mode): + """Validate that the rounding mode is one of the supported IEEE modes""" + valid_modes = {"rn", "rz", "ru", "rd"} + if isinstance(rounding_mode, str) and rounding_mode in valid_modes: + return + raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}") + + +def __log(x): + """Calculate log(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log"), x) + + +def __log2(x): + """Calculate log2(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log2"), x) + + +def __log10(x): + """Calculate log10(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log10"), x) + + +def __tan(x): + """Calculate tan(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__tan"), x) + + +def __cos(x): + """Calculate cos(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__cos"), x) + + +def __sin(x): + """Calculate sin(x) with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. 
+ """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__sin"), x) + + +def __exp10(x): + """Calculate 10**x with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp10"), x) + + +def __exp(x): + """Calculate 2**x with fast math + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp"), x) + + +# IEEE-compliant operations +def ieee_add(x, y, rounding_mode="rn"): + """IEEE-compliant addition with specified rounding mode + + Parameters + ---------- + x : PrimExpr + First operand. + y : PrimExpr + Second operand. + rounding_mode : str, optional + Rounding mode: 'rn' (round to nearest), 'rz' (round toward zero), + 'ru' (round toward positive infinity), 'rd' (round toward negative infinity). + Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result. + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + y = tir.convert(y) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_add"), x, y, rounding_mode) + + +def ieee_sub(x, y, rounding_mode="rn"): + """IEEE-compliant subtraction with specified rounding mode + + Parameters + ---------- + x : PrimExpr + First operand. + y : PrimExpr + Second operand. + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result. + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + y = tir.convert(y) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_sub"), x, y, rounding_mode) + + +def ieee_mul(x, y, rounding_mode="rn"): + """IEEE-compliant multiplication with specified rounding mode + + Parameters + ---------- + x : PrimExpr + First operand. + y : PrimExpr + Second operand. + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result. + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + y = tir.convert(y) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_mul"), x, y, rounding_mode) + + +def ieee_fmaf(x, y, z, rounding_mode="rn"): + """IEEE-compliant fused multiply-add with specified rounding mode + + Parameters + ---------- + x : PrimExpr + First operand. + y : PrimExpr + Second operand. + z : PrimExpr + Third operand (addend). + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result of x * y + z. + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + y = tir.convert(y) + z = tir.convert(z) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_fmaf"), x, y, z, rounding_mode) + + +def ieee_frcp(x, rounding_mode="rn"): + """IEEE-compliant reciprocal with specified rounding mode + + Parameters + ---------- + x : PrimExpr + Input operand. + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result of 1/x. 
+ """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_frcp"), x, rounding_mode) + + +def ieee_fsqrt(x, rounding_mode="rn"): + """IEEE-compliant square root with specified rounding mode + + Parameters + ---------- + x : PrimExpr + Input operand. + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result of sqrt(x). + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_fsqrt"), x, rounding_mode) + + +def ieee_frsqrt(x): + """IEEE-compliant reciprocal square root (round to nearest only) + + Parameters + ---------- + x : PrimExpr + Input operand. + + Returns + ------- + result : PrimExpr + The result of 1/sqrt(x). + """ + x = tir.convert(x) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_frsqrt"), x) + + +def ieee_fdiv(x, y, rounding_mode="rn"): + """IEEE-compliant division with specified rounding mode + + Parameters + ---------- + x : PrimExpr + Dividend. + y : PrimExpr + Divisor. + rounding_mode : str, optional + Rounding mode: 'rn', 'rz', 'ru', 'rd'. Default is 'rn'. + + Returns + ------- + result : PrimExpr + The result of x/y. + """ + _validate_rounding_mode(rounding_mode) + x = tir.convert(x) + y = tir.convert(y) + rounding_mode = tir.convert(rounding_mode) + return tir.call_intrin(x.dtype, tir.op.Op.get("tl.ieee_fdiv"), x, y, rounding_mode) + + +__all__ = [ + "__log", # noqa: F401 + "__log2", # noqa: F401 + "__log10", # noqa: F401 + "__tan", # noqa: F401 + "__cos", # noqa: F401 + "__sin", # noqa: F401 + "__exp10", # noqa: F401 + "__exp", # noqa: F401 + "ieee_add", # noqa: F401 + "ieee_sub", # noqa: F401 + "ieee_mul", # noqa: F401 + "ieee_fmaf", # noqa: F401 + "ieee_frcp", # noqa: F401 + "ieee_fsqrt", # noqa: F401 + "ieee_frsqrt", # noqa: F401 + "ieee_fdiv", # noqa: F401 +] diff --git a/tilelang/original/tilelang/language/overrides/__init__.py b/tilelang/original/tilelang/language/overrides/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c900642fa89a02cab5907253c927838817831ed5 --- /dev/null +++ b/tilelang/original/tilelang/language/overrides/__init__.py @@ -0,0 +1,8 @@ +"""TileLang-specific runtime overrides. + +Importing this package registers custom handlers that extend or override +behavior from upstream TVMScript for TileLang semantics. +""" + +# Register parser overrides upon import. +from . 
import parser # noqa: F401 diff --git a/tilelang/original/tilelang/language/overrides/parser.py b/tilelang/original/tilelang/language/overrides/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2fcc44f965c08207283e35966607ef69405c25 --- /dev/null +++ b/tilelang/original/tilelang/language/overrides/parser.py @@ -0,0 +1,155 @@ +"""TVMScript parser overrides tailored for TileLang.""" + +from functools import partial + +from tvm.script.ir_builder import tir as T +from tvm.script.parser._core import dispatch, doc +from tvm.tir import BufferLoad, Var + +from tvm.script.parser.tir import parser as tvm_tir_parser + + +def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]: + """Return the span (lineno, col, end_lineno, end_col) for a doc node.""" + return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) + + +# Original implementation located at +# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_assign). +@dispatch.register(token="tir", type_name="Assign") +def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=unused-argument + """Override `Assign` to support chained writes and `local.var` buffers.""" + if not node.targets: + self.report_error(node, "Assignment must have at least one target.") + + if isinstance(node.value, doc.Subscript): + check_slices = [] + if isinstance(node.value.slice, doc.Slice): + check_slices = [node.value.slice] + elif isinstance(node.value.slice, doc.Tuple): + for part in node.value.slice.elts: + if isinstance(part, doc.Slice): + check_slices.append(part) + for slice_node in check_slices: + if not slice_node.step and slice_node.upper and slice_node.lower: + slice_node.step = doc.Constant( + 1, + None, + 1, + 1, + slice_node.upper.lineno, + slice_node.upper.end_col_offset + 1, + slice_node.upper.lineno, + slice_node.upper.end_col_offset + 2, + ) + + rhs = self.eval_expr(node.value) + for lhs in node.targets: + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [self.eval_expr(index) for index in lhs.slice.elts] + else: + indices = self.eval_expr(lhs.slice) + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + continue + + if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): + load_ctx = doc.Load() + store_ctx = doc.Store() + lhs.ctx = load_ctx + lhs_value = self.eval_expr(lhs) + lhs.ctx = store_ctx + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): + T.buffer_store(lhs_value.buffer, rhs, indices=[0]) + continue + + self.eval_assign(target=lhs, source=rhs, bind_value=tvm_tir_parser.bind_assign_value) + + +# Original implementation located at +# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_aug_assign). 
+@dispatch.register(token="tir", type_name="AugAssign") +def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: disable=unused-argument + """Override `AugAssign` to support writes into `local.var` buffers.""" + lhs_pos = _get_node_span(node.target) + rhs_pos = _get_node_span(node.value) + + node.target.ctx = doc.Load() + with self.var_table.with_frame(): + lhs_name = "__tvm_tmp_value_aug_assign_lhs" + rhs_name = "__tvm_tmp_value_aug_assign_rhs" + lhs_expr = self.eval_expr(node.target) + rhs_expr = self.eval_expr(node.value) + self.var_table.add(lhs_name, lhs_expr) + self.var_table.add(rhs_name, rhs_expr) + op = doc.BinOp( + doc.Name(lhs_name, doc.Load(), *lhs_pos), + node.op, + doc.Name(rhs_name, doc.Load(), *rhs_pos), + *lhs_pos, + ) + rhs = self.eval_expr(op) + + lhs = node.target + lhs.ctx = doc.Store() + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [self.eval_expr(index) for index in lhs.slice.elts] + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + return + + if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): + load_ctx = doc.Load() + store_ctx = doc.Store() + lhs.ctx = load_ctx + lhs_value = self.eval_expr(lhs) + lhs.ctx = store_ctx + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): + T.buffer_store(lhs_value.buffer, rhs, indices=[0]) + return + + self.eval_assign(target=lhs, source=rhs, bind_value=tvm_tir_parser.bind_assign_value) + + +# Original implementation located at +# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_ann_assign). +@dispatch.register(token="tir", type_name="AnnAssign") +def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: disable=unused-argument + """Override `AnnAssign` to support writes into `local.var` buffers.""" + lhs = node.target + rhs = self.eval_expr(node.value) + ann_var = self.visit_tvm_annotation(node.annotation) + if not isinstance(ann_var, Var): + self.report_error(node.annotation, "Annotation should be Var") + + if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): + load_ctx = doc.Load() + store_ctx = doc.Store() + lhs.ctx = load_ctx + lhs_value = self.eval_expr(lhs) + lhs.ctx = store_ctx + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): + T.buffer_store(lhs_value.buffer, rhs, indices=[0]) + return + + self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value) + frame = T.LetStmt(rhs, var=ann_var) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() diff --git a/tilelang/original/tilelang/language/parser/__init__.py b/tilelang/original/tilelang/language/parser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a39e436cb0671f95536323f19367d61dfb0b1cda --- /dev/null +++ b/tilelang/original/tilelang/language/parser/__init__.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). +# ruff: noqa +"""The tir parser""" + +from typing import TYPE_CHECKING + +from ..ast import * # pylint: disable=redefined-builtin +from ..ast import ir as _tir +from . import operation as _operation +from . import parser as _parser +from .entry import Buffer, Ptr + +if TYPE_CHECKING: + # pylint: disable=invalid-name + # Define prim_func and make it type check as static method + # so most tvmscript won't trigger pylint error here. + prim_func = staticmethod +else: + from .entry import macro, prim_func + +__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func", "macro"] diff --git a/tilelang/original/tilelang/language/parser/entry.py b/tilelang/original/tilelang/language/parser/entry.py new file mode 100644 index 0000000000000000000000000000000000000000..53316d8c28388422ca4b139080b081e4fcb93210 --- /dev/null +++ b/tilelang/original/tilelang/language/parser/entry.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). +# ruff: noqa +"""The entry point of TVM parser for tir.""" + +import inspect +from typing import Callable, Optional, Union + +from tvm.ir.base import deprecated +from tvm.tir import Buffer, PrimFunc + +from ..ast import buffer, ptr +from tvm.script.parser._core import parse, scan_macro, utils +from tvm.script.parser.core.parser import Parser, ScriptMacro + + +def prim_func(func: Optional[Callable] = None, private: bool = False, check_well_formed=True) -> Union[PrimFunc, Callable]: + """The parsing method for tir prim func, by using `@prim_func` as decorator. + + Parameters + ---------- + func : Callable + The function to be parsed as prim func. + (Listed as optional to allow the decorator to be used + without arguments, like `@prim_func`, + or with an argument, `@prim_func(private=True)`) + + private : bool, optional + Whether the function should be treated as private. + A private function has no global symbol attribute; + if the function is not private, it will have a global symbol + matching the function name. + + Returns + ------- + res : Union[PrimFunc, Callable] + The parsed tir prim func. 
+ """ + # pylint: disable=unused-argument + # (private will be used in the parser, but not immediately) + + # need to capture this var outside the wrapper because the wrapper + # adds to the stack + outer_stack = inspect.stack() + + def decorator_wrapper(func): + if not inspect.isfunction(func): + raise TypeError(f"Expect a function, but got: {func}") + if utils.is_defined_in_class(outer_stack, func): + return func + f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed) + setattr(f, "__name__", func.__name__) + return f + + if func is not None: + # no optional args given => use wrapper directly + return decorator_wrapper(func) + else: + # if there is an optional arg given, return a new decorator + # that will then be invoked + setattr(decorator_wrapper, "dispatch_token", "tir") + return decorator_wrapper + + +setattr(prim_func, "dispatch_token", "tir") + +# Semantics of TIR macros: +# - Function that is decorated with @T.macro can have any parameters that +# follow Python syntax, i.e. positional, keyword, etc. Type annotations +# are not required, but are allowed. +# - Macro use follows the same syntax as a function call. +# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into +# the body of the macro, and the body with the substituted values is then +# inserted at the point where the call to the macro is located. + + +class TIRMacro(ScriptMacro): + """Specialization of the ScriptMacro class for TIR.""" + + def parse_macro(self, parser: Parser) -> None: + macro_def = self.get_macro_def() + parser.visit_body(macro_def.body) + + +def macro(*args, hygienic: bool = True) -> Callable: + """Decorator for macro definitions. + + Parameters + ---------- + hygienic: bool + Specifies whether the macro is hygienic or not. + A macro is hygienic if all symbols used in the macro's body are resolved + to values from the location of the macro definition. A non-hygienic macro + will have its symbols resolved to values at the time of the macro's use. + + Example: + ``` + import tvm + from tvm.script import tir as T + + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] ### x_value binds to 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] ### x_value will bind at the time of use + + + @T.prim_func + def use1(A: T.Tensor((1024,), "int32"), B: T.Tensor((), "int32")) -> None: + for x_value in T.serial(10): + static_capture(A, B) ### Produces B[()] = A[128] + + @T.prim_func + def use2(A: T.Tensor((1024,), "int32"), B: T.Tensor((), "int32")) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) ### Produces B[()] = A[x_value] + ``` + """ + + def _decorator(func: Callable) -> TIRMacro: + source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) + obj = TIRMacro(source, closure_vars, func, hygienic) + obj.__name__ = func.__name__ + return obj + + if len(args) == 0: + return _decorator + if len(args) == 1 and inspect.isfunction(args[0]): + return _decorator(args[0]) + + raise ValueError("Invalid use of T.macro. 
Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") + + +class BufferProxy: + """Buffer proxy class for constructing tir buffer.""" + + def __call__( + self, + shape, + dtype=T.float32, + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Buffer: + return buffer( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + @deprecated("T.Tensor[...]", "T.Tensor(...)") + def __getitem__(self, keys) -> Buffer: + if not isinstance(keys, tuple): + return self(keys) + if len(keys) >= 2 and not isinstance(keys[1], str): + return self(keys) + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member + + +class PtrProxy: + """Ptr proxy class for constructing tir pointer.""" + + @deprecated("T.Ptr(...)", "T.handle(...)") + def __call__(self, dtype, storage_scope="global"): + if callable(dtype): + dtype = dtype().dtype + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member + + @deprecated("T.Ptr[...]", "T.handle(...)") + def __getitem__(self, keys): + if not isinstance(keys, tuple): + return self(keys) + return self(*keys) + + +Buffer = BufferProxy() # pylint: disable=invalid-name +Ptr = PtrProxy() # pylint: disable=invalid-name diff --git a/tilelang/original/tilelang/language/parser/operation.py b/tilelang/original/tilelang/language/parser/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..473da43275a6252588288f3bafda1572d314b978 --- /dev/null +++ b/tilelang/original/tilelang/language/parser/operation.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). 
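The operator-registration module that follows overloads comparisons and boolean ops on `PrimExpr` so that plain Python scalars are promoted to the operand's base dtype and single-lane operands are broadcast before a lane-wise comparison. A minimal stand-alone illustration with plain `tvm.tir` constructors (not TileLang-specific) of what `vec == 3` resolves to once the hook runs:

```python
from tvm import tir

vec = tir.Broadcast(tir.IntImm("int32", 7), 4)   # an int32x4 operand
scalar = tir.IntImm("int32", 3)                  # the Python int, promoted to the base dtype
cmp = tir.EQ(tir.Broadcast(scalar, 4), vec)      # broadcast to 4 lanes, then lane-wise EQ
```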
+"""The tir expression operation registration""" + +from tvm import tir +from tvm.ffi.runtime_ctypes import DataType, DataTypeCode +from tvm.tir import IntImm +from tvm.tir.expr import FloatImm + +from tvm.script.parser._core import OpMethod, doc, register_op + + +def _register_expr_op(ty: type): # pylint: disable=invalid-name + ty._dispatch_type = ty # pylint: disable=protected-access + + def _and(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + return a & b + else: + return tir.And(a, b) + + def _or(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + return a | b + else: + return tir.Or(a, b) + + def _get_type_str(dtype: str): + if DataType(dtype).lanes == 1: + return dtype + index = dtype.find("x") + return dtype[0:index] + + def _auto_broadcast(a, b, op): + if isinstance(a, int): + if hasattr(b, "dtype"): + if DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT: + a = IntImm(_get_type_str(b.dtype), a) + elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: + a = FloatImm(_get_type_str(b.dtype), a) + elif isinstance(b, float): + a = FloatImm("float32", a) + else: + a = IntImm("int32", a) + elif isinstance(a, float): + if DataType(b.dtype).type_code == DataTypeCode.FLOAT: + a = FloatImm(_get_type_str(b.dtype), a) + else: + a = FloatImm("float32", a) + + assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." + if isinstance(b, int): + if DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT: + b = IntImm(_get_type_str(a.dtype), b) + elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: + b = FloatImm(_get_type_str(a.dtype), b) + elif isinstance(b, float): + b = FloatImm(_get_type_str(a.dtype), b) + + if DataType(a.dtype).lanes == DataType(b.dtype).lanes: + return op(a, b) + elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: + broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) + return op(broadcast_a, b) + elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: + broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) + return op(a, broadcast_b) + else: + raise TypeError("do not know how to deal with it.") + + def _eq(a, b): + return _auto_broadcast(a, b, tir.EQ) + + def _ne(a, b): + return _auto_broadcast(a, b, tir.NE) + + def _lt(a, b): + return _auto_broadcast(a, b, tir.LT) + + def _le(a, b): + return _auto_broadcast(a, b, tir.LE) + + def _gt(a, b): + return _auto_broadcast(a, b, tir.GT) + + def _ge(a, b): + return _auto_broadcast(a, b, tir.GE) + + def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name + register_op(ty, op, i)(m) + + for i in [0, 1]: + # Case 1. binop + # doc.Add <-- is overloaded + # doc.Sub <-- is overloaded + # doc.Mult <-- is overloaded + # doc.Div <-- is overloaded + # doc.FloorDiv <-- is overloaded + # doc.Mod <-- is overloaded + # doc.LShift <-- is overloaded + # doc.RShift <-- is overloaded + # doc.BitOr <-- is overloaded + # doc.BitXor <-- is overloaded + # doc.BitAnd <-- is overloaded + # doc.MatMult <-- not implemented + # doc.Pow <-- not implemented + # Case 2. 
cmpop + r(doc.Eq, i, _eq) + r(doc.NotEq, i, _ne) + r(doc.Lt, i, _lt) + r(doc.LtE, i, _le) + r(doc.Gt, i, _gt) + r(doc.GtE, i, _ge) + # doc.Is <-- not implemented + # doc.IsNot <-- not implemented + # doc.In <-- not implemented + # doc.NotIn <-- not implemented + # Case 3. boolop + r(doc.And, i, _and) + r(doc.Or, i, _or) + for i in [0]: + # Case 4. unaryop + # doc.Invert <-- is overloaded + r(doc.Not, i, tir.Not) + # doc.UAdd <-- is overloaded + # doc.USub <-- is overloaded + + +_register_expr_op(tir.PrimExpr) +_register_expr_op(tir.IterVar) diff --git a/tilelang/original/tilelang/language/parser/parser.py b/tilelang/original/tilelang/language/parser/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4cac0ad74f56de23c3edd2145fae444bef281423 --- /dev/null +++ b/tilelang/original/tilelang/language/parser/parser.py @@ -0,0 +1,585 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This file is modified from the original version, +# which is part of the TVM project (https://tvm.apache.org/). +# ruff: noqa +"""The base parser for tir""" + +import contextlib +from functools import partial +from typing import Any + +import tvm +from tvm.ir import GlobalVar, PrimType +from tvm.tir import Buffer, IterVar, PrimExpr, Var + +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T + +# May rewrite some register functions +# if we use our own registration +# from .. import ast as T + +from tvm.script.ir_builder.base import IRBuilder +from tvm.script.ir_builder.base import IRBuilderFrame as Frame +from tvm.script.parser._core import Parser, dispatch, doc + + +def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing with statement. + e.g. binding i, j, k with T.grid(128, 128, 128), when parsing + with T.grid(128, 128, 18) as i, j, k. + + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, (Buffer, Var)): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement") + raise NotImplementedError + + +def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing for statement. + e.g. binding i, j, k with T.grid(128, 128, 128), when parsing + for i, j, k in T.grid(128, 128, 128). 
+ + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, (list, tuple, tvm.ir.Array)): + for i, v in enumerate(value): + bind_for_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Var): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement") + raise NotImplementedError + + +def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing assign statement. + e.g. binding vi, vj, vk with T.axis.remap("SSR", [i, j, k]), when parsing + vi, vj, vk = T.axis.remap("SSR", [i, j, k]). + + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, T.meta_var): + return value.value + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_assign_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Frame): + value.add_callback(partial(value.__exit__, None, None, None)) + res = value.__enter__() + IRBuilder.name(var_name, res) + return res + elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and not self.var_table.exist(value)): + IRBuilder.name(var_name, value) + return value + else: + value = tvm.runtime.convert(value) + frame = T.LetStmt(value) + var = frame.var + IRBuilder.name(var_name, var) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + return var + + +def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool: + """ + Check the value of given annotation (argument name) in the prim_func decorator. + Returns the value of the annotation if present, otherwise giving the default value. + """ + # look for the named argument in the prim_func decorator + for dec in node.decorator_list: + if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func": + continue + for keyword in dec.keywords: + if keyword.arg == annotation: + return keyword.value.value + return default + + +@dispatch.register(token="tir", type_name="For") +def visit_for(self: Parser, node: doc.For) -> None: + """The for visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.For + The doc AST for node. + """ + for_frame = self.eval_expr(node.iter) + if not isinstance(for_frame, T.frame.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with for_frame as iters: + self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="While") +def visit_while(self: Parser, node: doc.While) -> None: + """The while visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.While + The doc AST while node. 
+ """ + with self.var_table.with_frame(): + cond = self.eval_expr(node.test) + with T.While(cond): + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + """The assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Assign + The doc AST assign node. + """ + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + + if isinstance(node.value, doc.Subscript): + check_slices = [] + if isinstance(node.value.slice, doc.Slice): + check_slices = [node.value.slice] + elif isinstance(node.value.slice, doc.Tuple): + for p in node.value.slice.elts: + if isinstance(p, doc.Slice): + check_slices.append(p) + for s in check_slices: + if not s.step and s.upper and s.lower: + s.step = doc.Constant( + 1, + None, + 1, + 1, + s.upper.lineno, + s.upper.end_col_offset + 1, + s.upper.lineno, + s.upper.end_col_offset + 2, + ) + + rhs = self.eval_expr(node.value) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = self.eval_expr(lhs.slice) + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AugAssign") +def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: + """The augmented assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.AugAssign + The doc AST augmented assign node. + """ + lhs_pos = ( + node.target.lineno, + node.target.col_offset, + node.target.end_lineno, + node.target.end_col_offset, + ) + rhs_pos = ( + node.value.lineno, + node.value.col_offset, + node.value.end_lineno, + node.value.end_col_offset, + ) + node.target.ctx = doc.Load(*lhs_pos) + with self.var_table.with_frame(): + lhs_name = "__tvm_tmp_value_aug_assign_lhs" + rhs_name = "__tvm_tmp_value_aug_assign_rhs" + lhs_expr = self.eval_expr(node.target) + rhs_expr = self.eval_expr(node.value) + self.var_table.add(lhs_name, lhs_expr) + self.var_table.add(rhs_name, rhs_expr) + op = doc.BinOp( + doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos), + node.op, + doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos), + *lhs_pos, + ) + rhs = self.eval_expr(op) + lhs = node.target + lhs.ctx = doc.Store(*lhs_pos) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + """The annotated assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.AnnAssign + The doc AST annotated assign node. 
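A minimal sketch of the annotated-assign form this handler accepts: the annotation must evaluate to a `Var`, and the value is bound through a `LetStmt` so later statements read the let-bound variable.

```python
import tilelang.language as T

@T.prim_func
def ann_assign_demo(A: T.Tensor((8,), "int32"), B: T.Tensor((8,), "int32")):
    with T.Kernel(1, threads=8):
        tx = T.get_thread_binding()
        x: T.int32 = A[tx] + 1     # visit_ann_assign: emits a LetStmt binding x
        B[tx] = x * x              # subsequent uses read the let-bound variable
```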
+ """ + lhs = node.target + rhs = self.eval_expr(node.value) + ann_var = self.visit_tvm_annotation(node.annotation) + if not isinstance(ann_var, Var): + self.report_error(node.annotation, "Annotation should be Var") + self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) + frame = T.LetStmt(rhs, var=ann_var) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + """The with visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.With + The doc AST with node. + """ + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, Frame): + self.report_error(item.context_expr, "Invalid context expression in the with-statement.") + rhs = stack.enter_context(frame) + if item.optional_vars is not None: + self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + """The function definition visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.FunctionDef + The doc AST function definition node. + """ + supplied_annotation = self.function_annotations + func_annotation = supplied_annotation.get(node.name, {}) + privacy = find_decorator_annotation(node, "private", default=False) + self.function_annotations = None + with self.var_table.with_frame(): + self.var_table.add("range", T.serial) + with T.prim_func(is_private=privacy): + T.func_name(node.name) + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + T.func_ret(ret_type) + with self.with_dispatch_token("tir"): + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation required for function parameters.") + try: + ann = self.eval_expr(arg.annotation) + if callable(ann): + ann = ann() + except Exception: # pylint: disable=broad-except + ann = func_annotation.get(arg.arg, None) + if ann is None: + raise + param = T.arg(arg.arg, ann) + self.var_table.add(arg.arg, param) + self.visit_body(node.body) + self.function_annotations = supplied_annotation + + +@dispatch.register(token="tir", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + """The TVM annotation visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.expr + The doc AST expr node. + """ + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + return annotation + + +@dispatch.register(token="tir", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + """The expr statement visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Expr + The doc AST Expr node. 
+ """ + + res = self.eval_expr(node.value) + if res is None: + pass + elif isinstance(res, Frame): + res.add_callback(partial(res.__exit__, None, None, None)) + res.__enter__() + elif isinstance(res, PrimExpr): + T.evaluate(res) + elif isinstance(res, (int, bool)): + T.evaluate(tvm.tir.const(res)) + elif isinstance(res, (tvm.relay.Call, tvm.relax.Call)) and not res.args: + # Using GlobalVar.__call__ with no arguments is ambiguous, as + # each IR has a different function Call representation. If + # this occurs, convert to the TIR representation. + T.evaluate(tvm.tir.call_tir(res.op)) + elif isinstance(res, str): + # Ignore docstrings + pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) + else: + self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") + + +@dispatch.register(token="tir", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + """The if visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.If + The doc AST if node. + """ + with self.var_table.with_frame(): + predicate = self.eval_expr(node.test) + if isinstance(predicate, (PrimExpr, tvm.tir.expr.ExprOp)): + with T.If(self.eval_expr(node.test)): + with T.Then(): + with self.var_table.with_frame(): + self.visit_body(node.body) + if node.orelse: + with T.Else(): + with self.var_table.with_frame(): + self.visit_body(node.orelse) + elif isinstance(predicate, bool): + if predicate: + with self.var_table.with_frame(): + self.visit_body(node.body) + elif node.orelse: + with self.var_table.with_frame(): + self.visit_body(node.orelse) + else: + self.report_error(node.test, f"If condition must be a boolean expression, but got {predicate}") + + +@dispatch.register(token="tir", type_name="Assert") +def visit_assert(self: Parser, node: doc.Assert) -> None: + """The assert visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Assert + The doc AST assert node. + """ + cond = self.eval_expr(node.test) + msg = self.eval_expr(node.msg) + frame = T.Assert(cond, msg) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="Return") +def visit_return(self: Parser, node: doc.Return) -> None: + """The return visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + value = self.eval_expr(node.value) + if value is None: + self.report_error(node, "Expression to be returned must be a PrimExpr") + T.evaluate(tvm.tir.ret(value)) + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. 
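The `If` handler above folds a plain Python boolean predicate at parse time (only the taken branch is visited) and emits a runtime branch only for `PrimExpr` predicates. A hedged sketch:

```python
import tilelang.language as T

USE_BIAS = True   # ordinary Python constant, resolved while the script is parsed

@T.prim_func
def branch_demo(A: T.Tensor((8,), "float32"), B: T.Tensor((8,), "float32")):
    with T.Kernel(1, threads=8):
        tx = T.get_thread_binding()
        if USE_BIAS:              # Python bool: only this body is parsed into the IR
            B[tx] = A[tx] + 1.0
        if A[tx] > 0.0:           # PrimExpr: becomes a runtime conditional
            B[tx] = B[tx] * 2.0
```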
+ """ + + supplied_annotation = self.function_annotations + func_annotation = supplied_annotation.get(node.name, {}) + + ret_type = None + with self.var_table.with_frame(): + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + arg_annotations = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation required for function parameters.") + try: + ann = self.eval_expr(arg.annotation) + if callable(ann): + ann = ann() + except Exception: # pylint: disable=broad-except + ann = func_annotation.get(arg.arg, None) + if ann is None: + raise + + IRBuilder.name(arg.arg, ann) + arg_annotations.append(ann) + + func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type) + return I.decl_function(node.name, func_signature) diff --git a/tilelang/original/tilelang/language/print_op.py b/tilelang/original/tilelang/language/print_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bbaa119ed55d7adbe0637ffff56d617e4d616454 --- /dev/null +++ b/tilelang/original/tilelang/language/print_op.py @@ -0,0 +1,220 @@ +""" +This module provides macros and utilities for debugging TileLang (tl) programs. +It includes functionality to print variables, print values in buffers, conditionally execute debug prints and assert. +""" + +from tvm import tir +from typing import Any +import tilelang.language as T +from tilelang.language.kernel import get_thread_bindings +from tilelang.language import copy, macro, serial, alloc_shared +from tilelang.language.utils import index_to_coordinates + + +@macro +def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: + """ + Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes. + + Parameters: + var (tir.PrimExpr): The variable or expression to be printed. + + Returns: + tir.PrimExpr: The TIR expression for the debug print operation. + """ + tir.call_extern("handle", "debug_print_var", msg, var) + + +@macro +def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: + """ + Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True. + + Parameters: + condition (tir.PrimExpr): A TIR expression representing the condition to check. + var (tir.PrimExpr): The variable or expression to be printed. + + Returns: + tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True. + """ + if condition: + tir.call_extern("handle", "debug_print_var", msg, var) + + +@macro +def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: + """ + Conditionally prints the values of a flattened TIR buffer if the condition is True. + """ + if condition: + # Iterate through the buffer elements and print each one. + for i in serial(elems): + coords = index_to_coordinates(i, buffer.shape) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + else: + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + + +@macro +def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: + """ + Conditionally prints the values of a flattened TIR buffer if the condition is True. + + Parameters: + condition (tir.PrimExpr): A TIR expression representing the condition to check. 
+ buffer (tir.Buffer): The buffer whose values need to be printed. + elems (int): The number of elements in the buffer to print. + + Returns: + tir.PrimExpr: The TIR expression for the debug print operation. + """ + if condition: + # Iterate through the buffer elements and print each one. + for i in serial(elems): + coords = index_to_coordinates(i, buffer.shape) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + + +@macro +def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: + """ + Conditionally prints the values of a flattened TIR buffer if the condition is True. + + Parameters: + condition (tir.PrimExpr): A TIR expression representing the condition to check. + buffer (tir.Buffer): The buffer whose values need to be printed. + elems (int): The number of elements in the buffer to print. + + Returns: + tir.PrimExpr: The TIR expression for the debug print operation. + """ + smem = alloc_shared(buffer.shape, buffer.dtype, "shared") + copy(buffer, smem) + if condition: + # Iterate through the buffer elements and print each one. + for i in serial(elems): + coords = index_to_coordinates(i, buffer.shape) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords]) + + +@macro +def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: + """ + Conditionally prints the values of a flattened TIR buffer if the condition is True. + + Parameters: + condition (tir.PrimExpr): A TIR expression representing the condition to check. + buffer (tir.Buffer): The buffer whose values need to be printed. + elems (int): The number of elements in the buffer to print. + + Returns: + tir.PrimExpr: The TIR expression for the debug print operation. + """ + if condition: + # Iterate through the buffer elements and print each one. + for i in serial(elems): + coords = index_to_coordinates(i, buffer.shape) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + + +from tilelang.utils.target import check_cuda_availability +import warnings + +_IS_CUDA_AVAILABLE = check_cuda_availability() + + +@macro +def device_assert(condition: tir.PrimExpr, msg: str = ""): + """ + Device-side assert emulation. + Emits a device-side assert call on CUDA targets when CUDA is available. + The assert is always enabled and cannot be disabled at runtime. + """ + if _IS_CUDA_AVAILABLE: + if msg == "": + T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition) + else: + warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2) + T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg) + + +def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr: + """ + A generic print function that handles both TIR buffers and primitive expressions. + + - If the input is a TIR buffer, it prints its values, but only on the first thread (tx=0, ty=0, tz=0). + - If the input is a TIR primitive expression, it prints its value directly. + + Parameters: + obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr. + msg (str): An optional message to include in the print statement. + warp_group_id (int): The warp group id to print. + warp_id (int): The warp id to print. + print thread will be warp_group_id * warp_group_size + warp_id. 
+ + Returns: + tir.PrimExpr: The TIR expression for the debug print operation. + + Raises: + ValueError: If the input object type is unsupported. + """ + if isinstance(obj, tir.Buffer): + # Buffers must be printed in just one thread to avoid duplicate outputs. + # Retrieve the thread bindings for thread x, y, and z. + tx, ty, tz = get_thread_bindings() + warp_group_size = 128 + warp_size = 32 + main_lane = warp_group_id * warp_group_size + warp_id * warp_size + + # Flatten the buffer for consistent printing. This assumes a 1D flattened buffer. + buffer = obj + if buffer.scope() == "local": + # Get the number of elements in the buffer. + elems = 1 + for dim in buffer.shape: + elems *= dim + condition = True + if not msg: + msg = f"buffer<{buffer.name}, {buffer.dtype}>" + return print_local_buffer_with_condition(condition, buffer, elems, msg) + elif buffer.scope() == "local.fragment": + # Get the number of elements in the buffer. + elems = 1 + for dim in buffer.shape: + elems *= dim + + # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. + condition = tx == main_lane and ty == 0 and tz == 0 + if not msg: + msg = f"buffer<{buffer.name}, {buffer.dtype}>" + return print_fragment_buffer_with_condition(condition, buffer, elems, msg) + elif buffer.scope() in {"shared", "shared.dyn"}: + # Get the number of elements in the buffer. + elems = 1 + for dim in buffer.shape: + elems *= dim + + # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. + condition = tx == main_lane and ty == 0 and tz == 0 + if not msg: + msg = f"buffer<{buffer.name}, {buffer.dtype}>" + return print_shared_buffer_with_condition(condition, buffer, elems, msg) + elif buffer.scope() == "global": + # Get the number of elements in the buffer. + elems = 1 + for dim in buffer.shape: + elems *= dim + condition = True + return print_global_buffer_with_condition(condition, buffer, elems, msg) + else: + raise ValueError(f"Unsupported buffer scope: {buffer.scope()}") + + elif isinstance(obj, tir.PrimExpr): + if not msg: + msg = f"expr<{obj}>" + # Directly print primitive expressions. + return print_var(obj, msg) + + else: + # Unsupported object type. + raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.") diff --git a/tilelang/original/tilelang/language/proxy.py b/tilelang/original/tilelang/language/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..b739de6b54871a28dc5f85acf03dadc0d5bb12d2 --- /dev/null +++ b/tilelang/original/tilelang/language/proxy.py @@ -0,0 +1,279 @@ +"""Buffer/Tensor proxy in TileLang.""" + +from __future__ import annotations + +from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar +from collections.abc import Sequence +from typing_extensions import Self + +from tvm import tir +from tvm.tir import Var, PrimExpr +from tvm.script.ir_builder.tir import buffer, handle, match_buffer +from tilelang.utils import deprecated + + +class BufferProxy: + """Buffer proxy class for constructing tir buffer.""" + + # Index via T.Buffer(...) 
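A hedged usage sketch for the debug-print and device-assert helpers defined in `print_op.py` above, assuming they are re-exported on the `tilelang.language` namespace as `T.print` and `T.device_assert`:

```python
import tilelang.language as T

@T.prim_func
def debug_demo(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
    with T.Kernel(1, threads=128):
        A_shared = T.alloc_shared((128,), "float32")
        T.copy(A, A_shared)
        T.print(A_shared, msg="A_shared")       # buffer path: emitted from a single thread
        T.print(A_shared[0] * 2.0, msg="expr")  # PrimExpr path: plain debug_print_var call
        T.device_assert(A_shared[0] >= 0.0)     # device-side assert, CUDA targets only
        T.copy(A_shared, B)
```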
+ @deprecated("T.Buffer(...)", "T.Tensor(...)") + def __call__( + self, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> tir.Buffer: + return buffer( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + # Index via T.Buffer[...] + @deprecated("T.Buffer[...]", "T.Tensor(...)") + def __getitem__(self, keys) -> tir.Buffer: + if not isinstance(keys, tuple): + return self(keys) + if len(keys) >= 2 and not isinstance(keys[1], str): + return self(keys) + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member + + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Buffer: + """Create a buffer from a pointer, shape, and data type. + + Args: + pointer_var: The pointer variable + shape: The shape of the buffer + dtype: The data type of the buffer (default: float32) + + Returns: + A buffer created from the given parameters + """ + return match_buffer(pointer_var, shape, dtype=dtype, strides=strides) + + +class BaseTensorProxy: + """Base proxy class for tensor types with configurable defaults. + + This class serves as a foundation for different tensor proxy types, providing + customizable default values for scope, alignment, and offset factors. It implements + the core functionality for creating TIR buffers with specific memory configurations. + """ + + default_scope = "global" + default_align = 0 + default_offset_factor = 0 + + def __call__( + self, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope=None, # Changed to None to use class default + align=None, + offset_factor=None, + buffer_type="", + axis_separators=None, + ) -> tir.Buffer: + # Use class defaults if not specified + scope = scope or self.default_scope + align = align or self.default_align + offset_factor = offset_factor or self.default_offset_factor + + return buffer( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, keys) -> tir.Buffer: + assert isinstance(keys, tuple) + # Single argument (the shape) + if all([type(s) not in (tuple, str, list) for s in keys]): + keys = (keys,) + return self(*keys) + + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> tir.Buffer: + """Create a buffer from a pointer, shape, and data type. + + Args: + pointer_var: The pointer variable + shape: The shape of the buffer + dtype: The data type of the buffer (default: float32) + + Returns: + A buffer created from the given parameters + """ + return match_buffer(pointer_var, shape, dtype=dtype, strides=strides) + + +class TensorProxy(BaseTensorProxy): + """Main tensor proxy class for global scope buffers. + + This class implements the default tensor proxy with global memory scope, + the tensor should be by default contiguous. 
+ """ + + @staticmethod + def _construct_strides(shape: tuple[Any]): + s, strides = 1, [1] + for dim in shape[:0:-1]: + s *= dim + strides.append(s) + return tuple(reversed(strides)) + + def __call__(self, shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer: + if isinstance(shape, (int, PrimExpr)): + shape = (shape,) + return super().__call__(shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data, scope=scope) + + +class StridedTensorProxy(BaseTensorProxy): + """Main tensor proxy class for global scope buffers, with strides supported. + + This class implements the default tensor proxy with global memory scope, with the stride information required. + """ + + def __call__(self, shape: tuple[Any], strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer: + if len(shape) != len(strides): + raise ValueError("Invalid shape/strides' dimensions") + return super().__call__(shape, dtype=dtype, strides=strides, scope=scope) + + +class FragmentBufferProxy(BaseTensorProxy): + """Proxy class for fragment memory buffers. + + This class represents tensor proxies specifically for local fragment memory, + typically used in GPU tensor core operations. + """ + + default_scope = "local.fragment" + + +class SharedBufferProxy(BaseTensorProxy): + """Proxy class for shared memory buffers. + + This class represents tensor proxies for dynamic shared memory, + commonly used in GPU shared memory operations. + """ + + default_scope = "shared.dyn" + + +class LocalBufferProxy(BaseTensorProxy): + """Proxy class for local memory buffers. + + This class represents tensor proxies for local memory scope, + typically used for temporary computations in GPU kernels. + """ + + default_scope = "local" + + +Buffer = BufferProxy() # pylint: disable=invalid-name +# Tensor is an alias for Buffer +# Because when user do jit compile, the input and output will +# be mapped with torch.Tensor. +if TYPE_CHECKING: + + class BaseTensor: + def __class_getitem__(cls, key): + return cls + + def __getitem__(self, key) -> Any: ... + + def __setitem__(self, key, value) -> None: ... + + def __init__( + self, + shape: Sequence[SupportsIndex], + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope=None, # Changed to None to use class default + align=None, + offset_factor=None, + buffer_type="", + axis_separators=None, + ): ... + + @classmethod + def from_ptr( + cls, pointer_var: Var, shape: Sequence[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Self: ... + + class Tensor(BaseTensor): ... + + class StridedTensor(BaseTensor): ... + + class FragmentBuffer(BaseTensor): ... + + class SharedBuffer(BaseTensor): ... + + class LocalBuffer(BaseTensor): ... + + _T = TypeVar("_T") + + class Ref(Generic[_T], tir.Var): ... +else: + Tensor = TensorProxy() # pylint: disable=invalid-name + StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name + FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name + SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name + LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name + + class Ref: ... + + +def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: + """Create a TIR var that represents a pointer. + + Parameters + ---------- + dtype: str + The data type of the pointer. + + storage_scope: str + The storage scope of the pointer. 
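A hedged sketch of the tensor proxies above together with the `make_tensor` helper defined just below: kernel parameters are annotated with `T.Tensor`, while a raw `T.handle` parameter can be rebound to a typed view. Exact placement conventions for `make_tensor` inside the kernel body are assumed here:

```python
import tilelang.language as T

@T.prim_func
def copy_from_ptr(src_ptr: T.handle, Out: T.Tensor((1024,), "float32")):
    with T.Kernel(8, threads=128):
        Src = T.make_tensor(src_ptr, (1024,), "float32")   # view the raw handle as a 1-D tensor
        bx = T.get_block_binding()
        tx = T.get_thread_binding()
        Out[bx * 128 + tx] = Src[bx * 128 + tx]
```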
+ + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + + Returns + ------- + res : PrimExpr + The new tir.Var with type handle or casted expression with type handle. + """ + return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var) + + +def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: + return Tensor.from_ptr(ptr, shape, dtype, strides) diff --git a/tilelang/original/tilelang/language/random.py b/tilelang/original/tilelang/language/random.py new file mode 100644 index 0000000000000000000000000000000000000000..a76625be2ef33d5b8fd10b1ff20a9d6e52483aad --- /dev/null +++ b/tilelang/original/tilelang/language/random.py @@ -0,0 +1,44 @@ +from tvm import tir +import tilelang.language as T + + +# https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview +def rng_init(seed, seq=None, off=0): + """Initialize CUDA curand random number generator state + + Parameters + ---------- + seed : PrimExpr + Random seed value. + seq : PrimExpr + Sequence number for parallel random number generation. + off : PrimExpr + Offset number for parallel random number generation. + + Returns + ------- + state : PrimExpr + The random number generator state handle. + """ + seed = tir.convert(seed) + if seq is None: + bx = T.get_block_binding() + ex = T.kernel.get_thread_extent() + tx = T.get_thread_binding() + id = tx + bx * ex + seq = tir.convert(id) + else: + seq = tir.convert(seq) + off = tir.convert(off) + return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off) + + +def rng_rand(): + """Generate a 32-bit unsigned random integer + + Returns + ------- + random_value : PrimExpr + A 32-bit unsigned random integer. + """ + return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand")) diff --git a/tilelang/original/tilelang/language/reduce_op.py b/tilelang/original/tilelang/language/reduce_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9db56df0d14f0923db59eb3f3444ba9cd3f45f65 --- /dev/null +++ b/tilelang/original/tilelang/language/reduce_op.py @@ -0,0 +1,464 @@ +"""Reduce operations exposed on the TileLang language surface.""" + +from __future__ import annotations +from tvm import tir +from tilelang.language import copy, macro, alloc_shared, alloc_fragment +from tilelang.utils.language import to_buffer_region, retrieve_shape, _get_buffer +from tilelang.utils.language import is_shared, is_fragment +from tvm.script.ir_builder import IRBuilder + + +def _legalize_dim(buffer: tir.Buffer, dim: int): + if dim < 0: + dim = len(buffer.shape) + dim + return dim + + +_REDUCE_OP_KEY = "tl.tileop.reduce" + + +def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): + """Perform a reduction operation on a buffer along a specified dimension. 
+ + Args: + buffer (tir.Buffer): Input buffer to reduce + out (tir.Buffer): Output buffer to store results + reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum') + dim (int): Dimension along which to perform reduction + clear (bool): Whether to initialize the output buffer before reduction + + Returns: + tir.Call: Handle to the reduction operation + """ + # input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y] + expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]] + if list(out.shape) not in expected_shapes: + expected_shapes_str = " or ".join(map(str, expected_shapes)) + raise ValueError( + f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}" + ) + + @macro + def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): + if is_shared(buffer) and is_shared(out): + red_frag_in = alloc_fragment(buffer.shape, buffer.dtype) + red_frag_out = alloc_fragment(out.shape, out.dtype) + + # rename buffers + IRBuilder.name(buffer.name + "_frag", red_frag_in) + IRBuilder.name(out.name + "_frag", red_frag_out) + + copy(buffer, red_frag_in) + tir.call_intrin( + "handle", + tir.op.Op.get(_REDUCE_OP_KEY), + to_buffer_region(red_frag_in, access_type="r"), + to_buffer_region(red_frag_out, access_type="w"), + reduce_type, + dim, + clear, + ) + copy(red_frag_out, out) + elif is_shared(buffer) and is_fragment(out): + red_frag_in = alloc_fragment(buffer.shape, buffer.dtype) + IRBuilder.name(buffer.name + "_frag", red_frag_in) + + copy(buffer, red_frag_in) + tir.call_intrin( + "handle", + tir.op.Op.get(_REDUCE_OP_KEY), + to_buffer_region(red_frag_in, access_type="r"), + to_buffer_region(out, access_type="w"), + reduce_type, + dim, + clear, + ) + elif is_fragment(buffer) and is_shared(out): + red_frag_out = alloc_fragment(out.shape, out.dtype) + IRBuilder.name(out.name + "_frag", red_frag_out) + + tir.call_intrin( + "handle", + tir.op.Op.get(_REDUCE_OP_KEY), + to_buffer_region(buffer, access_type="r"), + to_buffer_region(red_frag_out, access_type="w"), + reduce_type, + dim, + clear, + ) + copy(red_frag_out, out) + elif is_fragment(buffer) and is_fragment(out): + tir.call_intrin( + "handle", + tir.op.Op.get(_REDUCE_OP_KEY), + to_buffer_region(buffer, access_type="r"), + to_buffer_region(out, access_type="w"), + reduce_type, + dim, + clear, + ) + else: + raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}") + + return reduce_macro(buffer, out, reduce_type, dim, clear) + + +def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce max on input buffer, store the result to output buffer + + Parameters + ---------- + buffer : Buffer + The input buffer. + out : Buffer + The output buffer. + dim : int + The dimension to perform reduce on + clear : bool + If set to True, the output buffer will first be initialized to -inf. + Returns + ------- + handle : PrimExpr + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "max", dim, clear) + + +def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce min on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + clear (bool, optional): If True, output buffer will be initialized to inf. 
Defaults to True. + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "min", dim, clear) + + +def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce sum on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + clear (bool, optional): If True, output buffer will be cleared before reduction. + If False, results will be accumulated on existing values. + Defaults to True. + Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because + during warp reduction, the same value would be accumulated multiple times (number of threads + in the warp). Therefore, the implementation with clear=True follows these steps: + 1. create a temp buffer with same shape and dtype as out + 2. copy out to temp buffer + 3. call reduce_sum with temp buffer and out + 4. Add temp buffer to out + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "sum", dim, clear) + + +def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1): + """Perform reduce absolute sum on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "abssum", dim, True) + + +def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce absolute max on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "absmax", dim, clear) + + +def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce bitwise-and on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "bitand", dim, clear) + + +def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce bitwise-or on input buffer, store the result to output buffer. + + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "bitor", dim, clear) + + +def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): + """Perform reduce bitwise-xor on input buffer, store the result to output buffer. 
+ + Args: + buffer (tir.Buffer): The input buffer + out (tir.Buffer): The output buffer + dim (int): The dimension to perform reduce on + + Returns: + tir.Call: Handle to the reduction operation + """ + dim = _legalize_dim(buffer, dim) + return reduce(buffer, out, "bitxor", dim, clear) + + +@macro +def cumsum_fragment( + src: tir.Buffer, + dst: tir.Buffer, + dim: int, + reverse: bool, +) -> tir.PrimExpr: + """ + Compute cumulative sum for fragment buffers by copying to shared memory first. + + This macro handles cumulative sum operations on fragment buffers by first copying + the data to shared memory, performing the cumsum operation, and then copying back. + + Args: + src: Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data. + dst: Destination buffer (Buffer, BufferRegion, or BufferLoad) for output data. + dim: Dimension along which to compute cumulative sum. + reverse: If True, compute cumulative sum in reverse order. + + Returns: + tir.PrimExpr: A handle to the cumulative sum operation. + """ + src_shape = retrieve_shape(src) + src_buffer = _get_buffer(src) + # Get dtype from the buffer + if isinstance(src, tir.Buffer): + dtype = src.dtype + else: + dtype = src_buffer.dtype + cumsum_smem = alloc_shared(src_shape, dtype, "shared.dyn") + copy(src, cumsum_smem) + tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.cumsum"), + to_buffer_region(cumsum_smem, access_type="r"), + to_buffer_region(cumsum_smem, access_type="w"), + dim, + reverse, + ) + copy(cumsum_smem, dst) + + +def cumsum( + src: tir.Buffer | tir.BufferRegion | tir.BufferLoad, + dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad | None = None, + dim: int = 0, + reverse: bool = False, +): + """ + Compute the cumulative sum of `src` along `dim`, writing results to `dst`. + + Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic. + + Supports Buffer, BufferRegion, and BufferLoad inputs, allowing operations on buffer slices/regions. + + Examples: + A 1D inclusive scan that writes the result into a separate shared-memory buffer: + + >>> import tilelang.language as T + >>> @T.prim_func + ... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")): + ... with T.Kernel(1, threads=128): + ... A_shared = T.alloc_shared((128,), "float32") + ... T.copy(A, A_shared) + ... T.cumsum(src=A_shared, dst=A_shared, dim=0) + ... T.copy(A_shared, B) + + A 2D prefix sum along the last dimension with reverse accumulation: + + >>> import tilelang.language as T + >>> @T.prim_func + ... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")): + ... with T.Kernel(1, 1, threads=256): + ... tile = T.alloc_shared((64, 64), "float16") + ... T.copy(A, tile) + ... T.cumsum(src=tile, dim=1, reverse=True) + ... T.copy(tile, B) + + Operating on a buffer region (slice): + + >>> import tilelang.language as T + >>> @T.prim_func + ... def kernel_region(InputG_fragment: T.Tensor((128,), "float32"), chunk_size: T.int32): + ... with T.Kernel(1, threads=128): + ... i = T.int32(0) + ... T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0) + + Returns: + tir.Call: A handle to the emitted cumulative-sum operation. 
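In the same spirit as the cumsum examples above, a hedged sketch of the tile-level reducers defined earlier in this module, reducing a 2-D fragment along its last dimension:

```python
import tilelang.language as T

@T.prim_func
def row_stats(A: T.Tensor((64, 64), "float32"), M: T.Tensor((64,), "float32"), S: T.Tensor((64,), "float32")):
    with T.Kernel(1, threads=128):
        A_frag = T.alloc_fragment((64, 64), "float32")
        M_frag = T.alloc_fragment((64,), "float32")
        S_frag = T.alloc_fragment((64,), "float32")
        T.copy(A, A_frag)
        T.reduce_max(A_frag, M_frag, dim=1)   # row-wise max; clear=True initialises M_frag to -inf
        T.reduce_sum(A_frag, S_frag, dim=1)   # row-wise sum
        T.copy(M_frag, M)
        T.copy(S_frag, S)
```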
+ """ + + # Get shape from src (supports Buffer, BufferRegion, BufferLoad) + shape = retrieve_shape(src) + if dim >= len(shape) or dim < -len(shape): + raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}") + if dim < 0: + dim = len(shape) + dim + + if dst is None: + dst = src + else: + # Validate that dst shape matches src shape + dst_shape = retrieve_shape(dst) + if len(dst_shape) != len(shape): + raise ValueError(f"cumsum dst shape {dst_shape} must match src shape {shape} (rank mismatch)") + # Check each dimension matches + for i in range(len(shape)): + if not tir.analysis.expr_deep_equal(dst_shape[i], shape[i]): + raise ValueError(f"cumsum dst shape {dst_shape} must match src shape {shape} (dim {i} mismatch)") + + # Check if src is a fragment buffer + if is_fragment(src): + return cumsum_fragment(src, dst, dim, reverse) + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.cumsum"), + to_buffer_region(src, access_type="r"), + to_buffer_region(dst, access_type="w"), + dim, + reverse, + ) + + +def finalize_reducer(reducer: tir.Buffer): + """ + Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic. + + This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer. + The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR. + + Parameters: + reducer (tir.Buffer): Reducer buffer whose writable pointer will be finalized. + + Returns: + tir.Call: Handle to the finalize reducer intrinsic call. + """ + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.finalize_reducer"), + to_buffer_region(reducer, access_type="w"), + ) + + +def warp_reduce_sum(value: tir.PrimExpr): + """Perform warp reduction sum on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the sum of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced sum value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_sum"), value) + + +def warp_reduce_max(value: tir.PrimExpr): + """Perform warp reduction max on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the max of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced max value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_max"), value) + + +def warp_reduce_min(value: tir.PrimExpr): + """Perform warp reduction min on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the min of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced min value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_min"), value) + + +def warp_reduce_bitand(value: tir.PrimExpr): + """Perform warp reduction bitwise-and on a register value. 
+ + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the bitwise-and of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced bitwise-and value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitand"), value) + + +def warp_reduce_bitor(value: tir.PrimExpr): + """Perform warp reduction bitwise-or on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the bitwise-or of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced bitwise-or value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitor"), value) diff --git a/tilelang/original/tilelang/language/symbolics.py b/tilelang/original/tilelang/language/symbolics.py new file mode 100644 index 0000000000000000000000000000000000000000..928edf82ce98ff5cb31d0ee56544e86d56260da1 --- /dev/null +++ b/tilelang/original/tilelang/language/symbolics.py @@ -0,0 +1,27 @@ +"""Symbolic variable helpers exposed on the TileLang language surface.""" + +from tvm import tir + +from tilelang.utils import deprecated + +__all__ = ["dynamic", "symbolic"] + + +def dynamic(name: str, dtype: str = "int32"): + """ + Create a TIR dynamic symbolic variable. + + Parameters: + name (str): Identifier for the variable in generated TIR. + dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32". + + Returns: + tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels. + """ + return tir.Var(name, dtype) + + +@deprecated("T.symbolic(...)", "T.dynamic(...)", "v0.1.9") +def symbolic(name: str, dtype: str = "int32"): + """Deprecated alias for `T.dynamic`.""" + return tir.Var(name, dtype) diff --git a/tilelang/original/tilelang/language/tir/__init__.py b/tilelang/original/tilelang/language/tir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a933f7ff65862ab21e2e5d1c2a93d40f11690a50 --- /dev/null +++ b/tilelang/original/tilelang/language/tir/__init__.py @@ -0,0 +1,2 @@ +from .entry import prim_func # noqa: F401 +from .ir import * # noqa: F401 diff --git a/tilelang/original/tilelang/language/tir/entry.py b/tilelang/original/tilelang/language/tir/entry.py new file mode 100644 index 0000000000000000000000000000000000000000..8d65786e44163450817c41258cad6f9ae6cc6428 --- /dev/null +++ b/tilelang/original/tilelang/language/tir/entry.py @@ -0,0 +1,117 @@ +from __future__ import annotations +import inspect +from typing import Callable + +import tvm.script.parser.tir.entry as _tir_entry +from tvm.tir.function import PrimFunc +from tvm.script.parser._core import parse, scan_macro, utils + + +def prim_func(func: Callable | None = None, private: bool = False, check_well_formed: bool = False) -> PrimFunc | Callable: + """The parsing method for tir prim func, by using `@prim_func` as decorator. + + Parameters + ---------- + func : Callable + The function to be parsed as prim func. 
+ (Listed as optional to allow the decorator to be used + without arguments, like `@prim_func`, + or with an argument, `@prim_func(private=True)`) + + private : bool, optional + Whether the function should be treated as private. + A private function has no global symbol attribute; + if the function is not private, it will have a global symbol + matching the function name. + + Returns + ------- + res : Union[PrimFunc, Callable] + The parsed tir prim func. + """ + # pylint: disable=unused-argument + # (private will be used in the parser, but not immediately) + + # need to capture this var outside the wrapper because the wrapper + # adds to the stack + outer_stack = inspect.stack() + + def decorator_wrapper(func): + if not inspect.isfunction(func): + raise TypeError(f"Expect a function, but got: {func}") + nonlocal outer_stack + if utils.is_defined_in_class(outer_stack, func): + outer_stack = None + return func + outer_stack = None + f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed) + setattr(f, "__name__", func.__name__) # noqa: B010 + return f + + if func is not None: + # no optional args given => use wrapper directly + return decorator_wrapper(func) + else: + # if there is an optional arg given, return a new decorator + # that will then be invoked + setattr(decorator_wrapper, "dispatch_token", "tir") # noqa: B010 + return decorator_wrapper + + +setattr(prim_func, "dispatch_token", "tir") # noqa: B010 + + +def macro(*args, hygienic: bool = True) -> Callable: + """Decorator for macro definitions. + + Parameters + ---------- + hygienic: bool + Specifies whether the macro is hygienic or not. + A macro is hygienic if all symbols used in the macro's body are resolved + to values from the location of the macro definition. A non-hygienic macro + will have its symbols resolved to values at the time of the macro's use. + + Example: + ``` + import tvm + from tvm.script import tir as T + + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] ### x_value binds to 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] ### x_value will bind at the time of use + + + @T.prim_func + def use1(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None: + for x_value in T.serial(10): + static_capture(A, B) ### Produces B[()] = A[128] + + @T.prim_func + def use2(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) ### Produces B[()] = A[x_value] + ``` + """ + + def _decorator(func: Callable) -> _tir_entry.TIRMacro: + source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) + obj = _tir_entry.TIRMacro(source, closure_vars, func, hygienic) + obj.__name__ = func.__name__ + return obj + + if len(args) == 0: + return _decorator + if len(args) == 1 and inspect.isfunction(args[0]): + return _decorator(args[0]) + + raise ValueError("Invalid use of T.macro. 
Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
+
+
+setattr(macro, "dispatch_token", "tir")  # noqa: B010
diff --git a/tilelang/original/tilelang/language/tir/ir.py b/tilelang/original/tilelang/language/tir/ir.py
new file mode 100644
index 0000000000000000000000000000000000000000..7723da713c668fc7bb8fa262537751256f8def1b
--- /dev/null
+++ b/tilelang/original/tilelang/language/tir/ir.py
@@ -0,0 +1,303 @@
+import tvm.script.ir_builder.tir.ir as _ir
+from tvm.script.ir_builder.tir import frame
+from tvm.tir import PrimExpr
+from typing import Any
+import tilelang.language.tir.op as _tir_op
+import functools
+
+
+def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
+    """The serial For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    return _ir.serial(start=start, stop=stop, annotations=annotations)
+
+
+def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
+    """The parallel For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    return _ir.parallel(start=start, stop=stop, annotations=annotations)
+
+
+def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
+    """The vectorized For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    return _ir.vectorized(start=start, stop=stop, annotations=annotations)
+
+
+def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
+    """The unrolled For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    # Ensure annotations has {"pragma_unroll_explicit": False} by default
+    if annotations is None:
+        annotations = {"pragma_unroll_explicit": False}
+    else:
+        # Add "pragma_unroll_explicit": False if not already present
+        annotations = dict(annotations)
+        annotations.setdefault("pragma_unroll_explicit", False)
+    return _ir.unroll(start=start, stop=stop, annotations=annotations)
+
+
+def thread_binding(
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    thread: str = None,
+    *,
+    annotations: dict[str, Any] = None,
+) -> frame.ForFrame:
+    """The thread-binding For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    thread : str
+        The thread for loop variable to bind.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
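+
+    Examples
+    --------
+    A minimal sketch of binding a loop variable to ``threadIdx.x`` (buffer
+    names are illustrative)::
+
+        for tx in T.thread_binding(128, thread="threadIdx.x"):
+            B[tx] = A[tx]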
+ """ + return _ir.thread_binding(start=start, stop=stop, thread=thread, annotations=annotations) + + +def grid(*extents: PrimExpr) -> frame.ForFrame: + """The grid For statement. + + Parameters + ---------- + extents : PrimExpr + The extents of the iteration. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + return _ir.grid(*extents) + + +def _dtype_forward(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + args = (kwargs.pop("dtype"),) + args + return func(*args, **kwargs) + + return wrapped + + +def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + + +abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin +acos = _op_wrapper(_tir_op.acos) +acosh = _op_wrapper(_tir_op.acosh) +address_of = _op_wrapper(_tir_op.address_of) +asin = _op_wrapper(_tir_op.asin) +asinh = _op_wrapper(_tir_op.asinh) +atan = _op_wrapper(_tir_op.atan) +atan2 = _op_wrapper(_tir_op.atan2) +atanh = _op_wrapper(_tir_op.atanh) +bitwise_and = _op_wrapper(_tir_op.bitwise_and) +bitwise_not = _op_wrapper(_tir_op.bitwise_not) +bitwise_or = _op_wrapper(_tir_op.bitwise_or) +bitwise_xor = _op_wrapper(_tir_op.bitwise_xor) +ceil = _op_wrapper(_tir_op.ceil) +clz = _op_wrapper(_tir_op.clz) +copysign = _op_wrapper(_tir_op.copysign) +cos = _op_wrapper(_tir_op.cos) +cosh = _op_wrapper(_tir_op.cosh) +erf = _op_wrapper(_tir_op.erf) +exp = _op_wrapper(_tir_op.exp) +exp2 = _op_wrapper(_tir_op.exp2) +exp10 = _op_wrapper(_tir_op.exp10) +floor = _op_wrapper(_tir_op.floor) +ceildiv = _op_wrapper(_tir_op.ceildiv) +floordiv = _op_wrapper(_tir_op.floordiv) +floormod = _op_wrapper(_tir_op.floormod) +fmod = _op_wrapper(_tir_op.fmod) +hypot = _op_wrapper(_tir_op.hypot) +if_then_else = _op_wrapper(_tir_op.if_then_else) +infinity = _op_wrapper(_tir_op.infinity) +isfinite = _op_wrapper(_tir_op.isfinite) +isinf = _op_wrapper(_tir_op.isinf) +isnan = _op_wrapper(_tir_op.isnan) +isnullptr = _op_wrapper(_tir_op.isnullptr) +ldexp = _op_wrapper(_tir_op.ldexp) +likely = _op_wrapper(_tir_op.likely) +log = _op_wrapper(_tir_op.log) +log1p = _op_wrapper(_tir_op.log1p) +log2 = _op_wrapper(_tir_op.log2) +log10 = _op_wrapper(_tir_op.log10) +lookup_param = _op_wrapper(_tir_op.lookup_param) +max_value = _op_wrapper(_tir_op.max_value) +min_value = _op_wrapper(_tir_op.min_value) +nearbyint = _op_wrapper(_tir_op.nearbyint) +nextafter = _op_wrapper(_tir_op.nextafter) +popcount = _op_wrapper(_tir_op.popcount) +pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin +q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) +q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) +ret = _op_wrapper(_tir_op.ret) +round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin +rsqrt = _op_wrapper(_tir_op.rsqrt) +shift_left = _op_wrapper(_tir_op.shift_left) +shift_right = _op_wrapper(_tir_op.shift_right) +sigmoid = _op_wrapper(_tir_op.sigmoid) +sin = _op_wrapper(_tir_op.sin) +sinh = _op_wrapper(_tir_op.sinh) +sqrt = _op_wrapper(_tir_op.sqrt) +tan = _op_wrapper(_tir_op.tan) +tanh = _op_wrapper(_tir_op.tanh) +trunc = _op_wrapper(_tir_op.trunc) +truncdiv = _op_wrapper(_tir_op.truncdiv) +truncmod = _op_wrapper(_tir_op.truncmod) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) +tvm_stack_make_shape = 
_op_wrapper(_tir_op.tvm_stack_make_shape) +tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) +tvm_check_return = _op_wrapper(_tir_op.tvm_check_return) +call_packed = _op_wrapper(_tir_op.call_packed) +call_cpacked = _op_wrapper(_tir_op.call_cpacked) +call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) +call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered) +tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) +tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) +tvm_struct_get = _tir_op.tvm_struct_get +tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant) +tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce) +tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync) +tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync) +tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) +tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) +tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_warp_shuffle = _tir_op.tvm_warp_shuffle +tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up +tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down +tvm_warp_activemask = _tir_op.tvm_warp_activemask +ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) +ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) +ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier) +ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count) +ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier) +ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx) +ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier) +create_barriers = _op_wrapper(_tir_op.create_barriers) +assume = _op_wrapper(_tir_op.assume) +undef = _op_wrapper(_tir_op.undef) +TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) +TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) +start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) +end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +anylist_getitem = _op_wrapper(_tir_op.anylist_getitem) +anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) +anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) +anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) +vscale = _op_wrapper(_tir_op.vscale) + +reinterpret = _dtype_forward(_tir_op.reinterpret) +call_extern = _dtype_forward(_tir_op.call_extern) +call_intrin = _dtype_forward(_tir_op.call_intrin) +call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) +call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) +call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) +ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) +ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) +ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) +ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) +mma_store = _dtype_forward(_tir_op.mma_store) +mma_fill = _dtype_forward(_tir_op.mma_fill) +vectorlow = _dtype_forward(_tir_op.vectorlow) +vectorhigh = _dtype_forward(_tir_op.vectorhigh) +vectorcombine = _dtype_forward(_tir_op.vectorcombine) +tvm_mfma = _dtype_forward(_tir_op.tvm_mfma) +tvm_mmac = 
_dtype_forward(_tir_op.tvm_mmac) +tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store) +tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma) +tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store) diff --git a/tilelang/original/tilelang/language/tir/ir.pyi b/tilelang/original/tilelang/language/tir/ir.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7723f13782bb3e40f5ee4ba3b1242ff6360eb0e6 --- /dev/null +++ b/tilelang/original/tilelang/language/tir/ir.pyi @@ -0,0 +1,146 @@ +from typing import TypeVar, Literal +from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm + +_T = TypeVar("_T") + +def abs(x: _T, span: Span | None = None) -> _T: ... +def acos(x: _T) -> _T: ... +def acosh(x: _T) -> _T: ... +def address_of(buffer_load: BufferLoad, span: Span | None = None) -> PrimExpr: ... +def asin(x: _T) -> _T: ... +def asinh(x: _T) -> _T: ... +def atan(x: _T) -> _T: ... +def atan2(x1: _T, x2: _T) -> _T: ... +def atanh(x: _T) -> _T: ... +def bitwise_and(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_not(x: _T, span: Span | None = None) -> _T: ... +def bitwise_or(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_xor(x: _T, y: _T, span: Span | None = None) -> _T: ... +def ceil(x: _T, span: Span | None = None) -> _T: ... +def clz(x: _T) -> _T: ... +def copysign(x1: _T, x2: _T) -> _T: ... +def cos(x: _T) -> _T: ... +def cosh(x: _T) -> _T: ... +def erf(x: _T) -> _T: ... +def exp(x: _T) -> _T: ... +def exp2(x: _T) -> _T: ... +def exp10(x: _T) -> _T: ... +def floor(x: _T, span: Span | None = None) -> _T: ... +def ceildiv(lhs: _T, rhs: _T, span: Span | None = None) -> _T: ... +def floordiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def floormod(a: _T, b: _T, span: Span | None = None) -> _T: ... +def fmod(x: _T, y: _T) -> _T: ... +def hypot(x1: _T, x2: _T) -> _T: ... +def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None = None) -> _T: ... +def infinity(dtype: _T, span: Span | None = None) -> _T: ... +def isfinite(x: _T, span: Span | None = None) -> _T: ... +def isinf(x: _T, span: Span | None = None) -> _T: ... +def isnan(x: _T, span: Span | None = None) -> _T: ... +def isnullptr(x: _T, span: Span | None = None) -> _T: ... +def ldexp(x1: _T, x2: _T) -> _T: ... +def likely(cond: _T, span: Span | None = None) -> _T: ... +def log(x: _T) -> _T: ... +def log1p(x: _T) -> _T: ... +def log2(x: _T) -> _T: ... +def log10(x: _T) -> _T: ... +def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ... +def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def nearbyint(x: _T, span: Span | None = None) -> _T: ... +def nextafter(x1: _T, x2: _T) -> _T: ... +def popcount(x: _T) -> _T: ... +def pow(x: _T, y: _T, span: Span | None = None) -> _T: ... +def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ... +def q_multiply_shift_per_axis( + x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm +) -> PrimExpr: ... +def ret(val: _T) -> _T: ... +def round(x: _T, span: Span | None = None) -> _T: ... +def rsqrt(x: _T) -> _T: ... +def shift_left(x: _T, y: _T, span=None) -> _T: ... +def shift_right(x: _T, y: _T, span=None) -> _T: ... +def sigmoid(x: _T) -> _T: ... +def sin(x: _T) -> _T: ... +def sinh(x: _T) -> _T: ... +def sqrt(x: _T) -> _T: ... +def tan(x: _T) -> _T: ... +def tanh(x: _T) -> _T: ... +def trunc(x: _T, span: Span | None = None) -> _T: ... 
+def truncdiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def truncmod(a: _T, b: _T, span: Span | None = None) -> _T: ... +def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... +def tvm_throw_last_error() -> _T: ... +def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... +def tvm_stack_make_shape(*args) -> _T: ... +def tvm_stack_make_array( + data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset +) -> PrimExpr: ... +def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ... +def call_packed(*args, span=None) -> _T: ... +def call_cpacked(*args, span=None) -> _T: ... +def call_packed_lowered(*args, span=None) -> _T: ... +def call_cpacked_lowered(*args, span=None) -> _T: ... +def tvm_tuple(*value) -> _T: ... +def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... +def tvm_thread_invariant(cond: _T) -> _T: ... +def tvm_thread_allreduce(*freduce_args) -> _T: ... +def tvm_load_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... +def tvm_mma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... +def tvm_bmma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... +def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ... +def tvm_store_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... +def ptx_wait_group(num: int) -> PrimExpr: ... +def ptx_commit_group() -> _T: ... +def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... +def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ... +def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... +def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... +def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... +def create_barriers(barrier_count: int) -> PrimExpr: ... +def assume(cond: _T = None) -> _T: ... +def undef() -> _T: ... +def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... +def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... +def start_profile_intrinsic(id: int) -> PrimExpr: ... +def end_profile_intrinsic(id: int) -> PrimExpr: ... +def anylist_getitem(list_handle, index) -> PrimExpr: ... +def anylist_resetitem(list_handle, index) -> PrimExpr: ... +def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ... +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ... +def vscale() -> _T: ... 
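The loop helpers in `ir.py` above are thin pass-throughs to the TVMScript IR builder; the notable additions are that `unroll` defaults `pragma_unroll_explicit` to `False` and that the `_dtype_forward` aliases accept the result dtype as a `dtype=` keyword. A rough sketch of how the loop constructs compose inside a TileLang `prim_func` (shapes and buffer names are illustrative, not taken from the codebase):

```python
import tilelang.language as T


@T.prim_func
def add_one(A: T.Tensor((8, 128), "float32"), B: T.Tensor((8, 128), "float32")):
    # Unrolled outer loop; the wrapper injects pragma_unroll_explicit=False by default.
    for i in T.unroll(8):
        # Vectorized inner loop over the contiguous dimension.
        for j in T.vectorized(128):
            B[i, j] = A[i, j] + 1.0
```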
diff --git a/tilelang/original/tilelang/language/tir/op.py b/tilelang/original/tilelang/language/tir/op.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6d544e1778c6ae6f13e98145680ee2bea3c9a9 --- /dev/null +++ b/tilelang/original/tilelang/language/tir/op.py @@ -0,0 +1,3481 @@ +from __future__ import annotations +from typing import Any +import tvm +from tvm.ir import PrimExpr +from tvm.ir.base import Span +from tvm.runtime import const +from tvm.tir.expr import IntImm, PrimExprWithOp +import tvm.tir.op as _tvm_op + + +def call_packed(*args, span=None): + """Build expression by call an external packed function. + + The argument to packed function can be Expr or Buffer. + The argument is the corresponding POD type when Expr is presented. + + When the argument is Buffer, the corresponding PackedFunc + will receive an TVMArrayHandle whose content is valid during the callback period. + If the PackedFunc is a python callback, then the corresponding argument is NDArray. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + te.extern : Create tensor with extern function call. + """ + return _tvm_op.call_packed(*args, span=span) + + +def call_cpacked(*args, span=None): + """Build expression by call an external packed function. + + Same as call_packed, except that the first argument is the function name + (as in call_extern), and the last argument is the resource handle. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + te.extern : Create tensor with extern function call. + """ + return _tvm_op.call_cpacked(*args, span=span) + + +def call_packed_lowered(*args, span=None): + """Lowered version of call packed. + The argument to packed function can be Expr or Buffer. + The argument is the corresponding POD type when Expr is presented. + When the argument is Buffer, the corresponding PackedFunc + will receive an TVMArrayHandle whose content is valid during the callback period. + If the PackedFunc is a python callback, then the corresponding argument is NDArray. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + te.extern : Create tensor with extern function call. + """ + return _tvm_op.call_packed_lowered(*args, span=span) + + +def call_cpacked_lowered(*args, span=None): + """Lowered version of call c-packed. + Same as call_packed, except that the first argument is the function name + (as in call_extern), and the last argument is the resource handle. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + te.extern : Create tensor with extern function call. + """ + return _tvm_op.call_cpacked_lowered(*args, span=span) + + +def call_intrin(dtype, func_name, *args, span=None): + """Build expression by calling an intrinsic function. 
+ + Intrinsics can be overloaded with multiple data types via + the intrinsic translation rule. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The intrinsic function name. + + args : list + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_intrin(dtype, func_name, *args, span=span) + + +def call_pure_extern(dtype, func_name, *args, span=None): + """Build expression by calling a pure extern function. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The extern function name. + + args : list + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_pure_extern(dtype, func_name, *args, span=span) + + +def call_extern(dtype, func_name, *args, span=None): + """Build expression by calling a extern function. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The extern function name. + + args : list + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_extern(dtype, func_name, *args, span=span) + + +def call_llvm_intrin(dtype, name, *args, span=None): + """Build expression by calling a llvm intrinsic function + + Parameters + ---------- + dtype : str + The data type of the result. + + name : str + The name of the llvm intrinsic function. + + args : list + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_llvm_intrin(dtype, name, *args, span=span) + + +def call_llvm_pure_intrin(dtype, name, *args, span=None): + """Build expression by calling a pure llvm intrinsic function + + Parameters + ---------- + dtype : str + The data type of the result. + + name : str + The name of the llvm intrinsic function. + + args : list + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_llvm_pure_intrin(dtype, name, *args, span=span) + + +def tvm_check_return(expected, return_unexpected, nested_call): + """Return new on stack dtype[num] + Parameters + ---------- + expected : int + The expected return code. + return_unexpected : int + The unexpected return code. + nested_call : PrimExpr + The call expression to check return. + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_check_return(expected, return_unexpected, nested_call) + + +def tvm_stack_alloca(dtype_str, num): + """Return new on stack dtype[num] + + Parameters + ---------- + dtype_str : str + The data type of array. + + num : int + The size of array. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_stack_alloca(dtype_str, num) + + +def tvm_stack_make_shape(*args): + """Allocate a shape tuple on stack, return the handle + + Parameters + ---------- + args : int + The tuple shape. + + Returns + ------- + call : PrimExpr + The call expression. 
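+
+    Examples
+    --------
+    A rough sketch (names and argument values are illustrative); the returned
+    handle is typically passed on to `tvm_stack_make_array` when building a
+    DLTensor on the stack::
+
+        shape = T.tvm_stack_make_shape(m, n)   # handle to an on-stack shape tuple
+        arr = T.tvm_stack_make_array(A.data, shape, 0, 2, T.float32(0), 0)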
+ """ + return _tvm_op.tvm_stack_make_shape(*args) + + +def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset): + """Allocate a NDArray(DLTensor) on stack, return the handle + + Parameters + ---------- + data : Expr + The data of array. + + shape : Expr + The shape of array. + + strides : Expr + The strides of array. + + ndim : Expr + The dimensions of array. + + arr_dtype : Expr + The data type of array. + + elem_offse : Expr + The element offset of array. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset) + + +def assume(cond=None): + """Provide a true statement that can be used for simplifications + + Parameters + ---------- + cond : Expr + The constraint condition. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.assume(cond) + + +def undef(): + """Returns an initialized but arbitrary value + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.undef() + + +def call_tir(global_var: tvm.ir.GlobalVar, *args): + """Performs a call into another PrimFunc in the same IRModule + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.call_tir(global_var, *args) + + +def start_profile_intrinsic(id): + """Start profile intrinsic. + Parameters + ---------- + id : int + The intrinsic id. + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.start_profile_intrinsic(id) + + +def end_profile_intrinsic(id): + """End profile intrinsic. + Parameters + ---------- + id : int + The intrinsic id. + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.end_profile_intrinsic(id) + + +def tvm_tuple(*value): + """Create a tuple structure in value field of AttrStmt + + Parameters + ---------- + value : Expr + The value in tuple. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_tuple(*value) + + +def tvm_struct_get(arr, index, field, dtype): + """Get struct field value in array + + Parameters + ---------- + dtype : str + The date type of the result. + + arr : StructType* + The array of struct. + + index : int + The index of struct. + + field : int + The field of struct. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_struct_get(arr, index, field, dtype) + + +def tvm_struct_set(arr, index, field, value): + """Set value in struct field in array + + Parameters + ---------- + arr : StructType* + The array of struct. + + index : int + The index of struct. + + field : int + The field of struct. + + value : Expr + The value to be set in field. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_struct_set(arr, index, field, value) + + +def address_of(buffer_load, span=None): + """Returns the address of an element in the buffer + + Parameters + ---------- + buffer_load: BufferLoad + The buffer load. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.address_of(buffer_load, span=span) + + +def lookup_param(param_name, span=None): + """Returns the param by name + + Parameters + ---------- + param_name : str + The name of param. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. 
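+
+    Examples
+    --------
+    A minimal sketch (the parameter name is illustrative)::
+
+        w = T.lookup_param("conv1_weight")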
+ """ + return _tvm_op.lookup_param(param_name, span=span) + + +def tvm_thread_allreduce(*freduce_args): + """Perform allreduce inside threadblock. + + Parameters + ---------- + freduce_args : Expr + The args. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_thread_allreduce(*freduce_args) + + +def tvm_thread_invariant(cond): + """Mark condition as thread invariant. + + Parameters + ---------- + cond : Expr + The condition. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_thread_invariant(cond) + + +def tvm_storage_sync(storage_scope): + """Perform synchronization in specified scope. + + Parameters + ---------- + storage_scope : str + The storage scope to perform synchronization. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_storage_sync(storage_scope) + + +def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): + """Exchange value between threads inside a warp. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + warp_id : PrimExpr + The source lane index to fetch value. + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_warp_shuffle(mask, value, warp_id, width, warp_size) + + +def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): + """Copy value from a lane with lower (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = dst_lane_idx - src_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_warp_shuffle_up(mask, value, offset, width, warp_size) + + +def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): + """Copy value from a lane with higher (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = src_lane_idx - dst_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_warp_shuffle_down(mask, value, offset, width, warp_size) + + +def tvm_warp_activemask(): + """Return a 32-bit mask indicates currently active threads in a calling warp. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_warp_activemask() + + +def type_annotation(dtype): + """Create a type annotation expression + + Parameters + ---------- + dtype : Expr + The data type. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.type_annotation(dtype) + + +def tvm_access_ptr(ptype, data, offset, extent, rw_mask): + """Get head access address with memory access pattern info + + Parameters + ---------- + ptype : Expr + The data type of pointer. + + data : DType* + The data of pointer. 
+ + offset : int + The offset of pointer. + + extent : int + The extent of pointer. + + rw_mask : int + The read write mask. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_access_ptr(ptype, data, offset, extent, rw_mask) + + +def tvm_throw_last_error(): + """Throw TVMGetLastError() + + Returns + ------- + ret : PrimExpr + The return expression + """ + return _tvm_op.tvm_throw_last_error() + + +def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core load operators + + Parameters + ---------- + fragment : Var + The wmma fragment. + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. + + buffer_ptr : Expr + The fragment buffer pointer. + + stride : Expr + The fragment stride. + + layout : Literal["row_major", "column_major"] + The fragment layout. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout) + + +def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): + """TVM intrinsic for tensor core mma_sync operators + + Parameters + ---------- + fragment_d : Var + The wmma fragment_d. + + index_d : Expr + The fragment_d index. + + fragment_a : Var + The wmma fragment_a. + + index_a : Expr + The fragment_a index. + + fragment_b : Var + The wmma fragment_b. + + index_b : Expr + The fragment_b index. + + fragment_c : Var + The wmma fragment_c. + + index_c : Expr + The fragment_c index. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) + + +def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): + """TVM intrinsic for tensor core bmma_sync operators + + Parameters + ---------- + fragment_d : Var + The bwmma fragment_d. + + index_d : Expr + The fragment_d index. + + fragment_a : Var + The bwmma fragment_a. + + index_a : Expr + The fragment_a index. + + fragment_b : Var + The bwmma fragment_b. + + index_b : Expr + The fragment_b index. + + fragment_c : Var + The bwmma fragment_c. + + index_c : Expr + The fragment_c index. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) + + +def tvm_fill_fragment(fragment, m, n, k, index, value): + """TVM intrinsic for tensor core fill_fragment operators + + Parameters + ---------- + fragment : Var + The wmma fragment + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. + + value : Expr + The value to be filled in fragment. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_fill_fragment(fragment, m, n, k, index, value) + + +def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core store operators + + Parameters + ---------- + fragment : Var + The wmma fragment. + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. 
+ + buffer_ptr : Expr + The fragment buffer pointer. + + stride : Expr + The fragment stride. + + layout : Literal["row_major", "column_major"] + The fragment layout. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout) + + +def ptx_mma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + saturate, + operator=None, +): + """TVM intrinsic for ptx tensor core mma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + saturate : bool + The optional saturation at the output. + + operator : Optional[Literal["xor", "and"]] + The 1-bit operator. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_mma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + saturate, + operator, + ) + + +def ptx_mma_sp( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + metadata, + meta_index, + sparse_selector, + saturate, +): + """TVM intrinsic for sparse tensor core ptx instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment B. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + metadata : Expr + The metadata of operand. + + meta_index : Expr + The metadata index of operand. + + sparse_selector : Expr + The sparse selector indicating the thread that stores the metadata. + + saturate : bool + The optional saturation at the output. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return _tvm_op.ptx_mma_sp( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + metadata, + meta_index, + sparse_selector, + saturate, + ) + + +def ptx_wgmma_ss( + dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + """TVM intrinsic for ptx tensor core wmma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_ss"), + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + +def ptx_wgmma_rs( + dtype, + wgmma_prefix, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_rs"), + wgmma_prefix, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + +def ptx_tcgen05_mma_ss( + kind_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, + enable_ws=False, + ws=None, + warp_specialized=None, + variant=None, +): + """TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions. + + Expects 13 or 14 positional arguments: + (kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset, + desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]). + Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`. + Alternatively, use `variant="ws"` (or "default"). + - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16, + "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). + """ + # Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws + if ws is not None: + enable_ws = bool(ws) + if warp_specialized is not None: + enable_ws = bool(warp_specialized) + if variant is not None: + if isinstance(variant, str): + v = variant.lower() + if v in ("ws", "warp_specialized", "warp-specialized"): + enable_ws = True + elif v in ("default", "std", "ss"): + enable_ws = False + else: + raise ValueError(f"ptx_tcgen05_mma_ss: unknown variant: {variant}") + else: + # Treat non-string as truthy flag + enable_ws = bool(variant) + + return call_intrin( + "handle", + _tvm_op.Op.get("tl.ptx_tcgen05_mma_ss"), + kind_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, + enable_ws, + ) + + +def ptx_tcgen05_mma_ts( + kind_dtype, + A_ptr, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, +): + """TVM intrinsic for tcgen05.mma tensor-memory × shared-memory instructions. + + Expects 13 positional arguments: + (kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset, + desc_val, scale_out, mask0, mask1, mask2, mask3). 
+ - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16, + "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). + """ + return call_intrin( + "handle", + _tvm_op.Op.get("tl.ptx_tcgen05_mma_ts"), + kind_dtype, + A_ptr, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, + ) + + +def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """TVM intrinsic for storing the result of PTX MMA into a destination pointer + + Parameters + ---------- + dtype : str + The data type of the result. + + m : IntImm + The shape of mma fragment. + + n : IntImm + The shape of mma fragment. + + dst_ptr : Var + The destination pointer variable. + + src_ptr : Var + The source pointer variable. + + src_offset : Expr + The source offset. + + dst_stride : Var + The destination stride. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride) + + +def mma_fill(dtype, local_size, local_ptr, offset): + """TVM intrinsic for zero-initalizing an MMA accumulation register + + Parameters + ---------- + dtype : str + The data type of the result. + + local_size : IntImm + The number of elements. + + local_ptr : Var + The destination pointer variable. + + offset : Expr + The destination offset. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.mma_fill(dtype, local_size, local_ptr, offset) + + +def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset): + """TVM intrinsic for ptx load matrix from shared memory + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix + + Parameters + ---------- + dtype : str + The data type of the result. + + trans : bool + The matrix is loaded in column-major format. + + num : IntImm + The number of matrices. + + type : Literal[".b16"] + The data type of the matrices. + + local_ptr : Var + The local pointer variable. + + local_offset : Expr + The offset of local pointer. + + smem_ptr : Var + The shared memory pointer variable. + + smem_offset : Expr + The offset of shared memort pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) + + +def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes): + """TVM intrinsic for ptx async copy from global to shared memory using cp.async + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async + + Parameters + ---------- + dtype : str + The data type of the result. + + shared_ptr : Var + The shared memory pointer variable. + + shared_offset : Expr + The offset of shared memory pointer. + + global_ptr : Var + The global memory pointer variable. + + global_offset : Expr + The offset of global memory pointer. + + bytes : int + The data size to copy. + + Returns + ------- + call : PrimExpr + The call expression. 
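+
+    Examples
+    --------
+    A minimal sketch (buffer names, offsets and sizes are illustrative): each
+    thread issues a 16-byte asynchronous copy, then the group is committed and
+    awaited::
+
+        T.ptx_cp_async("float16", A_shared.data, 8 * tx, A.data, 8 * tx, 16)
+        T.ptx_commit_group()
+        T.ptx_wait_group(0)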
+ """ + return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes) + + +def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id): + """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk + + Parameters + ---------- + dtype : str + The data type of the result. + + shared_ptr : Var + The shared memory pointer variable. + + shared_offset : Expr + The offset of shared memory pointer. + + global_ptr : Var + The global memory pointer variable. + + global_offset : Expr + The offset of global memory pointer. + + bytes : int + The data size to copy. + + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id) + + +def ptx_commit_group(): + """TVM intrinsic for ptx async copy commit + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_commit_group() + + +def ptx_wait_group(num): + """TVM intrinsic for ptx async copy wait + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group + + Parameters + ---------- + num : int + The number of the most recent uncommitted pending cp.async groups to wait. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_wait_group(num) + + +def tvm_mfma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for amd matrix core mfma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.tvm_mfma"), + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) + + +def tvm_mmac( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for amd matrix core mfma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.tvm_mmac"), + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) + + +def tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """TVM intrinsic for storing the result of PTX MMA into a destination pointer + + Parameters + ---------- + dtype : str + The data type of the result. + + m : IntImm + The shape of mma fragment. + + n : IntImm + The shape of mma fragment. + + dst_ptr : Var + The destination pointer variable. + + src_ptr : Var + The source pointer variable. + + src_offset : Expr + The source offset. + + dst_stride : Var + The destination stride. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.tvm_mfma_store"), + m, + n, + dst_ptr, + src_ptr, + src_offset, + dst_stride, + ) + + +def tvm_rdna_wmma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for amd matrix core mfma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. 
+ + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.tvm_rdna_wmma"), + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) + + +def tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """TVM intrinsic for storing the result of PTX MMA into a destination pointer + + Parameters + ---------- + dtype : str + The data type of the result. + + m : IntImm + The shape of mma fragment. + + n : IntImm + The shape of mma fragment. + + dst_ptr : Var + The destination pointer variable. + + src_ptr : Var + The source pointer variable. + + src_offset : Expr + The source offset. + + dst_stride : Var + The destination stride. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.tvm_rdna_wmma_store"), + m, + n, + dst_ptr, + src_ptr, + src_offset, + dst_stride, + ) + + +def ptx_cp_async_barrier(barrier_id): + """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_cp_async_barrier(barrier_id) + + +def ptx_init_barrier_thread_count(barrier_id, thread_count): + """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + thread_count : int + Number of threads expected to arrive at the barrier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_init_barrier_thread_count(barrier_id, thread_count) + + +def ptx_arrive_barrier(barrier_id): + """TVM intrinsic for ptx barrier arrival using mbarrier.arrive + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_arrive_barrier(barrier_id) + + +def ptx_arrive_barrier_expect_tx(barrier_id, byte_count): + """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + byte_count : int + Increases the tx count of the mbarrier object to track completion of + additional async transactions. + + Returns + ------- + call : PrimExpr + The call expression. 
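+
+    Notes
+    -----
+    Hedged sketch of the usual pairing with the other mbarrier intrinsics in
+    this module; the barrier id, thread count and byte count are placeholder
+    values, and the usual ``T`` alias for this module is assumed.
+
+    .. code-block:: python
+
+        T.create_barriers(1)
+        T.ptx_init_barrier_thread_count(0, 128)
+        T.ptx_arrive_barrier_expect_tx(0, 4096)  # expect 4096 bytes of async traffic
+        T.ptx_cp_async_bulk("float16", smem_ptr, 0, gmem_ptr, 0, 4096, 0)
+        T.ptx_wait_barrier(0)  # completes once all arrivals and the 4096 bytes land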
+ """ + return _tvm_op.ptx_arrive_barrier_expect_tx(barrier_id, byte_count) + + +def ptx_wait_barrier(barrier_id): + """TVM intrinsic for ptx barrier wait using mbarrier.try_wait + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.ptx_wait_barrier(barrier_id) + + +def create_barriers(barrier_count): + """TVM intrinsic to create N barriers + + Parameters + ---------- + barrier_count : int + The number of barriers to create. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.create_barriers(barrier_count) + + +def vectorlow(dtype, vec): + """Get the low level half of the vector + + Parameters + ---------- + dtype : str + The data type of the result. + + vec : list + The input vector. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.vectorlow(dtype, vec) + + +def vectorhigh(dtype, vec): + """Get the high level half of the vector + + Parameters + ---------- + dtype : str + The data type of the result. + + vec : list + The input vector. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.vectorhigh(dtype, vec) + + +def vectorcombine(dtype, vec1, vec2): + """Concat two vectors + + Parameters + ---------- + vec1 : list + The input vector. + + vec2 : list + The input vector. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.vectorcombine(dtype, vec1, vec2) + + +def ret(val): + """Create a tir return expression + + Parameters + ---------- + val : Expr + The returned tir expression, whose data type is int, float or void pointer. + + Returns + ------- + ret : PrimExpr + The return expression + """ + return _tvm_op.ret(val) + + +def any(*args, span=None): + """Create a new expression of the union of all conditions in the arguments + + Parameters + ---------- + args : list + List of symbolic boolean expressions + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + expr: Expr + Expression + """ + return _tvm_op.any(*args, span=span) + + +def all(*args, span=None): + """Create a new expression of the intersection of all conditions in the + arguments + + Parameters + ---------- + args : list + List of symbolic boolean expressions + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + expr: Expr + Expression + """ + return _tvm_op.all(*args, span=span) + + +def trace(args, trace_action="tvm.default_trace_action"): + """Trace tensor data at the runtime. + + The trace function allows to trace specific tensor at the + runtime. The tracing value should come as last argument. + The trace action should be specified, by default + tvm.default_trace_action is used. + + Parameters + ---------- + args : list of Expr or Buffers. + Positional arguments. + + trace_action : str. + The name of the trace action. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + tvm.tir.call_packed : Creates packed function. + """ + return _tvm_op.trace(args, trace_action) + + +def min_value(dtype, span=None): + """minimum value of dtype + + Parameters + ---------- + dtype : str + The data type. + + span : Optional[Span] + The location of this operator in the source code. 
+ + Returns + ------- + value : tvm.Expr + The minimum value of dtype. + """ + return _tvm_op.min_value(dtype, span) + + +def max_value(dtype: str, span: Span | None = None) -> Any: + """maximum value of dtype + + Parameters + ---------- + dtype : str + The data type. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The maximum value of dtype. + """ + return _tvm_op.max_value(dtype, span) + + +def infinity(dtype: str, span: Span | None = None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The infinity value of dtype. + """ + return call_intrin(dtype, _tvm_op.Op.get("tl.infinity"), dtype, span=span) + + +def reinterpret(dtype, value, span: Span | None = None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The reinterpret cast value of dtype. + """ + return _tvm_op.reinterpret(dtype, value, span) + + +def exp(x): + """Take exponential of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.exp(x) + + +def exp2(x): + """Calculate 2**x + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.exp2(x) + + +def exp10(x): + """Calculate 10**x + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.exp10(x) + + +def erf(x): + """Take gauss error function of the input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.erf(x) + + +def tanh(x): + """Take hyperbolic tanh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.tanh(x) + + +def sigmoid(x): + """Quick function to get sigmoid + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.sigmoid(x) + + +def log(x): + """Take log of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.log(x) + + +def log2(x): + """Take log2 of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.log2(x) + + +def log10(x): + """Take log10 of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.log10(x) + + +def log1p(x): + """Take log(x + 1) with respect to input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.log1p(x) + + +def tan(x): + """Take tan of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.tan(x) + + +def cos(x): + """Take cos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. 
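+
+    Examples
+    --------
+    Illustrative only; assumes the usual ``T`` alias for this module.
+
+    .. code-block:: python
+
+        y = T.cos(x) * T.cos(x) + T.sin(x) * T.sin(x)  # mathematically equal to 1.0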
+ """ + return _tvm_op.cos(x) + + +def cosh(x): + """Take cosh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.cosh(x) + + +def acos(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.acos(x) + + +def acosh(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.acosh(x) + + +def sin(x): + """Take sin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.sin(x) + + +def sinh(x): + """Take sinh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.sinh(x) + + +def asin(x): + """Take asin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.asin(x) + + +def asinh(x): + """Take asinh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.asinh(x) + + +def atan(x): + """Take atan of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.atan(x) + + +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.atanh(x) + + +def atan2(x1, x2): + """Take arctan2(x1, x2). + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.atan2(x1, x2) + + +def sqrt(x): + """Take square root of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.sqrt(x) + + +def rsqrt(x): + """Take reciprocal of square root of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.rsqrt(x) + + +def clz(x): + """Count leading zero bits of an integer x. + + Parameters + ---------- + x : PrimExpr + Input 32 or 64 bit integer. + The result is undefined if the input is 0. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.clz(x) + + +def floor(x: PrimExprWithOp, span=None): + """Take floor of float input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.floor(x, span) + + +def ceil(x, span=None): + """Take ceil of float input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.ceil(x, span) + + +def trunc(x, span=None): + """Get truncated value of the input. + + The truncated value of the scalar x is the + nearest integer i which is closer to zero than x is. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. 
+ + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.trunc(x, span) + + +def abs(x, span=None): + """Get absolute value of the input element-wise. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.abs(x, span) + + +def bitwise_and(x, y, span=None): + """Take bitwise and of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _tvm_op.bitwise_and(x, y, span) + + +def bitwise_not(x, span=None): + """Take bitwise not of input value + + Parameters + ---------- + x : PrimExpr + Input operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _tvm_op.bitwise_not(x, span) + + +def bitwise_or(x, y, span=None): + """Take bitwise or of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _tvm_op.bitwise_or(x, y, span) + + +def bitwise_xor(x, y, span=None): + """Take bitwise xor of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _tvm_op.bitwise_xor(x, y, span) + + +def round(x, span=None): + """Round elements of the array to the nearest integer. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.round(x, span) + + +def nearbyint(x, span=None): + """Round elements of the array to the nearest integer. + This intrinsic uses llvm.nearbyint instead of llvm.round + which is faster but will results different from te.round. + Notably nearbyint rounds according to the rounding mode, + whereas te.round (llvm.round) ignores that. + For differences between the two see: + https://en.cppreference.com/w/cpp/numeric/math/round + https://en.cppreference.com/w/cpp/numeric/math/nearbyint + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.nearbyint(x, span) + + +def nextafter(x1, x2): + """Return the next floating-point value after x1 towards x2. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.nextafter(x1, x2) + + +def hypot(x1, x2): + """Equivalent to sqrt(x1**2 + x2**2), element-wise. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.hypot(x1, x2) + + +def copysign(x1, x2): + """Change the sign of x1 to that of x2, element-wise. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. 
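+
+    Examples
+    --------
+    Standard ``copysign`` semantics, shown with literal values.
+
+    .. code-block:: python
+
+        copysign(3.0, -1.5)   # -> -3.0
+        copysign(-2.0, 4.0)   # ->  2.0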
+ """ + return _tvm_op.copysign(x1, x2) + + +def ldexp(x1, x2): + """Returns x1 * (2 ** x2). + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.ldexp(x1, x2) + + +def likely(cond, span=None): + """Mark condition as likely. + + Parameters + ---------- + + cond : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The marked expression. + """ + return _tvm_op.likely(cond, span) + + +def isnan(x, span=None): + """Check if input value is Nan. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.isnan(x, span) + + +def isnullptr(x, span=None): + """Check if input value is nullptr. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.isnullptr(x, span) + + +def isfinite(x, span=None): + """Check if input value is finite. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.isfinite(x, span) + + +def isinf(x, span=None): + """Check if input value is infinite. + + Parameters + ---------- + x : PrimExpr + Input argument. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.isinf(x, span) + + +def pow_of_int(x: PrimExpr, y: int) -> PrimExpr: + """Fast power operation than pow(float, float). + + Args: + x (PrimExpr): Base value + y (int): Exponent value + """ + return call_intrin( + x.dtype, + tvm.tir.op.Op.get("tl.pow_of_int"), + x, + y, + ) + + +def power(x, y, span=None): + """x power y + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + The exponent + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + z : PrimExpr + The result. + """ + if isinstance(y, (int, IntImm)): + return pow_of_int(x, y) + return _tvm_op.power(x, y, span) + + +def pow(x, y, span=None): + """x power y + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + The exponent + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + z : PrimExpr + The result. + """ + if isinstance(y, (int, IntImm)): + return pow_of_int(x, y) + return _tvm_op.pow(x, y, span) + + +def popcount(x): + """Count the number of set bits in input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.popcount(x) + + +def q_multiply_shift(x, y, q, s): + """Execute a multiplication between two Q-numbers x and y + followed by a right shift s. The mathematical expression is: + + out = round(x*y*2^-s) + + More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) + The rounding rule is to the nearest value, rounding half up + (i.e., round(x.1) = x and round (x.5) = x+1) + + Parameters + ---------- + x : PrimExpr + First Q-number + y : PrimExpr + Second Q-number + q : PrimExpr + Number of fractional bits in x and y. 
Needs to be > 0 + s : PrimExpr + Integer shift + + Returns + ------- + y : PrimExpr + The result. + """ + return _tvm_op.q_multiply_shift(x, y, q, s) + + +def q_multiply_shift_per_axis( + x: PrimExpr, + y: PrimExpr, + ls: PrimExpr, + rs: PrimExpr, + q: IntImm, + is_lshift_required: IntImm, + is_rshift_required: IntImm, +): + """Execute a multiplication between two Q-numbers x and y + + Parameters + ---------- + x : PrimExpr + First Q-number. + y : PrimExpr + Second Q-number. + ls : PrimExpr + Integer left shift. + rs : PrimExpr + Integer right shift. + q : IntImm + Number of fractional bits in x and y. Needs to be > 0. + is_lshift_required : IntImm + Whether we need to do left shift or not. + is_rshift_required : IntImm + Whether we need to do right shift or not. + + Returns + ------- + z : PrimExpr + The result. + """ + return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required) + + +def shift_left(x, y, span=None): + """Return the result of x left shifted by y bits. + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + Input argument. + + Returns + ------- + z : PrimExpr + The result. + """ + return _tvm_op.shift_left(x, y, span) + + +def shift_right(x, y, span=None): + """Return the result of x right shifted by y bits. + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + Input argument. + + Returns + ------- + z : PrimExpr + The result. + """ + return _tvm_op.shift_right(x, y, span) + + +def fmod(x, y): + """Return the remainder of x divided by y with the same sign as x. + + Parameters + ---------- + x : PrimExpr + Input argument. + y : PrimExpr + Input argument. + + Returns + ------- + z : PrimExpr + The result. + """ + return _tvm_op.fmod(x, y) + + +def if_then_else(cond, t, f, span=None): + """Conditional selection expression. + + Parameters + ---------- + cond : PrimExpr + The condition + + t : PrimExpr + The result expression if cond is true. + + f : PrimExpr + The result expression if cond is false. + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + result : Node + The result of conditional expression. + + Note + ---- + Unlike Select, if_then_else will not execute + the branch that does not satisfy the condition. + You can use it to guard against out of bound access. + Unlike Select, if_then_else cannot be vectorized + if some lanes in the vector have different conditions. + """ + return _tvm_op.if_then_else(cond, t, f, span) + + +def div(a, b, span=None): + """Compute a / b as in C/C++ semantics. + + Parameters + ---------- + a : PrimExpr + The left hand operand, known to be non-negative. + + b : PrimExpr + The right hand operand, known to be non-negative. + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + Note + ---- + When operands are integers, returns truncdiv(a, b, span). + """ + return _tvm_op.div(a, b, span) + + +def indexdiv(a, b, span=None): + """Compute floor(a / b) where a and b are non-negative. + + Parameters + ---------- + a : PrimExpr + The left hand operand, known to be non-negative. + + b : PrimExpr + The right hand operand, known to be non-negative. + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + Use this function to split non-negative indices. + This function may take advantage of operands' + non-negativeness. 
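+
+    Examples
+    --------
+    Splitting a flat, non-negative index into a row/column pair (sketch;
+    ``i`` is assumed to be a non-negative PrimExpr).
+
+    .. code-block:: python
+
+        row = indexdiv(i, 128)  # floor(i / 128)
+        col = indexmod(i, 128)  # i - row * 128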
+ """ + return _tvm_op.indexdiv(a, b, span) + + +def indexmod(a, b, span=None): + """Compute the remainder of indexdiv. a and b are non-negative. + + Parameters + ---------- + a : PrimExpr + The left hand operand, known to be non-negative. + + b : PrimExpr + The right hand operand, known to be non-negative. + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + Use this function to split non-negative indices. + This function may take advantage of operands' + non-negativeness. + """ + return _tvm_op.indexmod(a, b, span) + + +def truncdiv(a, b, span=None): + """Compute the truncdiv of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _tvm_op.truncdiv(a, b, span) + + +def truncmod(a, b, span=None): + """Compute the truncmod of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _tvm_op.truncmod(a, b, span) + + +def floordiv(a, b, span=None): + """Compute the floordiv of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _tvm_op.floordiv(a, b, span) + + +def floormod(a, b, span=None): + """Compute the floormod of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _tvm_op.floormod(a, b, span) + + +def ceildiv(lhs, rhs, span=None): + """Generic ceildiv operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + op : tvm.Expr + The result Expr of ceildiv operation. + """ + return _tvm_op.ceildiv(lhs, rhs, span) + + +def comm_reducer(fcombine, fidentity, name="reduce"): + """Create a commutative reducer for reduction. + + Parameters + ---------- + fcombine : function(Expr -> Expr -> Expr) + A binary function which takes two Expr as input to return a Expr. + + fidentity : function(str -> Expr) + A function which takes a type string as input to return a const Expr. + + Returns + ------- + reducer : function + A function which creates a reduce expression over axis. + There are two ways to use it: + + 1. accept (expr, axis, where) to produce an Reduce Expr on + specified axis; + 2. simply use it with multiple Exprs. + + Example + ------- + .. 
code-block:: python + + n = te.var("n") + m = te.var("m") + mysum = te.comm_reducer(lambda x, y: x+y, + lambda t: tvm.tir.const(0, dtype=t), name="mysum") + A = te.placeholder((n, m), name="A") + k = te.reduce_axis((0, m), name="k") + B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B") + """ + return _tvm_op.comm_reducer(fcombine, fidentity, name) + + +def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint): + """Backend function to allocate temporal workspace + + Parameters + ---------- + device_type : int + The device type which the space will be allocated. + + device_id : int + The device id which the space will be allocated. + + nbytes : int + The size of the space requested. + + dtype_code_hint : int + The type code of the array elements. Only used in certain backends such as OpenGL. + + dtype_bits_hint : int + The type bits of the array elements. Only used in certain backends such as OpenGL. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint) + + +def TVMBackendFreeWorkspace(device_type, device_id, ptr): + """Backend function to free temporal workspace. + + Parameters + ---------- + device_type : int + The device type which the space will be allocated. + + device_id : int + The device id which the space will be allocated. + + ptr : Var + The result allocated space pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.TVMBackendFreeWorkspace(device_type, device_id, ptr) + + +def anylist_getitem(list_handle, index): + """Returns an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.anylist_getitem(list_handle, index) + + +def anylist_resetitem(list_handle, index): + """Reset an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.anylist_resetitem(list_handle, index) + + +def anylist_setitem_call_packed(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.anylist_setitem_call_packed(list_handle, index, func_name, *args) + + +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return _tvm_op.anylist_setitem_call_cpacked(list_handle, index, func_name, *args) + + +def vscale(): + """Get the target's vscale value. 
It will be lowered to llvm.vscale intrinsic + (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) + Returns + ------- + call : PrimExpr + Call to the vscale intrinsic + """ + return _tvm_op.vscale() + + +# pylint: disable=unnecessary-lambda +sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") +min = comm_reducer(lambda x, y: _tvm_op._OpMin(x, y, None), max_value, name="min") +max = comm_reducer(lambda x, y: _tvm_op._OpMax(x, y, None), min_value, name="max") diff --git a/tilelang/original/tilelang/language/utils.py b/tilelang/original/tilelang/language/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6829419b5a21f0bbc3cde1759066fbd66f59fc --- /dev/null +++ b/tilelang/original/tilelang/language/utils.py @@ -0,0 +1,101 @@ +from tilelang import tvm as tvm +from tvm import tir +from tvm.tir import PrimExpr, BufferLoad, op +from tilelang import language as T + + +def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): + """Create a tl.region call for a BufferLoad and extents.""" + access_type = {"r": 1, "w": 2, "rw": 3}[access_type] + return T.call_intrin("handle", op.Op.get("tl.tileop.region"), buffer, access_type, *args) + + +def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]): + """Convert a BufferLoad to a tl.region call with explicit extents.""" + indices = list(load.indices) + if len(indices) > len(extents): + extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))] + list(extents) + assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" + return region(load, access_type, *extents) + + +def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]): + """Clamp extents and return a tl.region call.""" + mins = [r.min for r in buffer_region.region] + region_extents = [r.extent for r in buffer_region.region] + assert len(region_extents) >= len(extents), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" + clamped_extents = [ + tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents)) + ] + return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents) + + +def index_to_coordinates(index, shape) -> list[PrimExpr]: + """ + Convert a flat (linear) index into multi-dimensional coordinates for a given shape. + + Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates. + + Parameters: + index (int or PrimExpr): The flat index to convert. + shape (Sequence[int]): The extents of each dimension (length >= 1). + + Returns: + List[PrimExpr]: Coordinates for each dimension in the same order as `shape`. + """ + coordinates = [] + dims = len(shape) + for i in range(dims): + coordinates.append(index % shape[dims - i - 1]) + index = index // shape[dims - i - 1] + coordinates.reverse() + return coordinates + + +def linear_index(*args: PrimExpr) -> PrimExpr: + """ + Compute a flat (linear) index from multi-dimensional coordinates and strides. 
+ + The function accepts a sequence of PrimExpr arguments where the first portion are coordinates + and the trailing portion are the corresponding strides. The number of strides must equal + (number of coordinates - 1). The linear index is computed as: + + linear = coords[0] + for each (coord, stride) in zip(coords[1:], strides): + linear = linear * stride + coord + + Examples: + - linear_index(i) -> i + - linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed) + - linear_index(i, j, stride_j) -> i * stride_j + j + - linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k + - linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v + + Raises: + ValueError: If called with no arguments, or if the number of strides is not one less than + the number of coordinates. + + Returns: + PrimExpr: The computed linear index expression. + """ + n = len(args) + if n == 0: + raise ValueError("At least one index is required") + + if n == 1: + return args[0] + + # The first part is indices, the second part is strides (starting from the second dimension) + # A simpler way: the number of strides = total number of arguments - number of indices + # Actually, the args are designed as indices... + strides..., and the number of strides = number of indices - 1 + num_coords = (n + 1) // 2 + coords = args[:num_coords] + strides = args[num_coords:] + + if len(strides) != len(coords) - 1: + raise ValueError("Stride count must be one less than coordinate count") + + linear = coords[0] + for idx, stride in zip(coords[1:], strides): + linear = linear * stride + idx + return linear diff --git a/tilelang/original/tilelang/language/v2/__init__.py b/tilelang/original/tilelang/language/v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed4834388c7374215ed5d0158aa11d0d6fe3ca9 --- /dev/null +++ b/tilelang/original/tilelang/language/v2/__init__.py @@ -0,0 +1,2 @@ +from .builder import prim_func, macro, PrimFunc, PrimFuncCreater, Ref # noqa: F401 +from .dtypes import * diff --git a/tilelang/original/tilelang/language/v2/annot.py b/tilelang/original/tilelang/language/v2/annot.py new file mode 100644 index 0000000000000000000000000000000000000000..bac92142ce1c21cd17a2f59a1ffc26c5d91b3d8f --- /dev/null +++ b/tilelang/original/tilelang/language/v2/annot.py @@ -0,0 +1,714 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from tvm import tir +from tvm.ir.expr import PrimExpr +from tvm.script.ir_builder.tir import buffer +from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING + +# Python 3.9 compatibility for advanced typing features +try: + from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined] +except Exception: # Python < 3.10 for ParamSpec, < 3.11 for Unpack/TypeVarTuple/Self + from typing_extensions import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore + +# Compatibility for generic alias detection across Python versions +try: + from typing import _GenericAlias as _TypingGenericAlias # type: ignore[attr-defined] +except Exception: + _TypingGenericAlias = None # type: ignore +try: + # Builtin generic alias type for e.g. 
tuple[int] + from types import GenericAlias as _TypesGenericAlias # type: ignore[attr-defined] +except Exception: + _TypesGenericAlias = None # type: ignore + +_GenericAliasTypes = tuple(t for t in (_TypingGenericAlias, _TypesGenericAlias) if t is not None) +if not _GenericAliasTypes: + + class _DummyGenericAlias: # type: ignore + pass + + _GenericAliasTypes = (_DummyGenericAlias,) # type: ignore +from collections.abc import Sequence +from .dtypes import AnyDType +from . import dtypes as dt +import tvm.script.ir_builder.tir as tb_tir +from tvm.script.ir_builder import IRBuilder +import torch +import inspect + +_Shapes = TypeVarTuple("_Shapes") +_Shape = ParamSpec("_Shape") +_Stride = ParamSpec("_Stride") +_DType = TypeVar("_DType") + +Scope = Literal["global", "shared.dyn", "local", "local.fragment"] + + +class Annot(ABC): + """ + Base class for tilelang kernel annotations + Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel + + It provides 3 main functionalities: + 1. determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) + 2. parse the argument value into a hash key for jit caching + 3. convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation + """ + + def is_kernel_arg(self) -> bool: + """ + Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) + """ + return False + + @abstractmethod + def with_name(self: Self, name) -> Self: + pass + + @abstractmethod + def get_key_parser(self) -> Callable[[str, Any], tuple[Any, ...]]: + """ + Return a parser function that converts the argument value into a hash key for jit caching + """ + + @abstractmethod + def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable) -> tir.Var | tir.Buffer: + """ + Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation + """ + + def promote(self) -> TIRAnnot | None: + """ + Try to promote the annotation into a FixedAnnot if possible + Return None if not promotable + """ + return None + + +@dataclass +class ArgVarTable: + """ + ArgVarTable is used to manage the mapping from argument names to tir.Var objects + """ + + var_tab: dict[str, tir.Var] = field(default_factory=dict) + tmp_name_idx: int = 0 + + def get_or_create_var(self, name: str, dtype: dt.dtype) -> tir.Var: + if not name: + name = self.create_tmp_name() + if name not in self.var_tab: + self.var_tab[name] = tir.Var(name, dtype) + return self.var_tab[name] + + def create_tmp_name(self) -> str: + name = f"varg_{self.tmp_name_idx}" + self.tmp_name_idx += 1 + return name + + +@dataclass +class Value(Annot): + kind: Literal["static", "dynamic"] = "dynamic" + name: str | None = None + dtype: dt.dtype | None = dt.int32 + value: int | tir.Var | None = None + creator: Callable[[], Any] | None = None + + def is_kernel_arg(self) -> bool: + return self.kind == "dynamic" + + @classmethod + def from_value(cls, value: Any, prefer_name: str = None) -> Value: + if isinstance(value, int): + # handle A: T.Tensor[[1024, 1024], ...] 
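+            # a literal int dim is a compile-time constant: it is not a kernel
+            # argument, and its value goes straight into the JIT cache key
+            # (see get_key_parser below)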
+ return Value(kind="static", name=prefer_name, dtype=dt.int32, value=value) + elif isinstance(value, float): + return Value(kind="static", name=prefer_name, dtype=dt.float32, value=value) + elif isinstance(value, dt.dtype): + # handle A: T.float32 + return Value(kind="dynamic", name=prefer_name, dtype=value, value=None) + elif isinstance(value, Value): + # handle A: T.dyn + return value + elif isinstance(value, TypeVar): + return Value(kind="static", name=value.__name__, value=None) + elif isinstance(value, (tir.Var, PrimExpr)): + # handle A: T.Tensor[[M, N, K], ...] + # or primexpr annotation like A: T.Tensor[[M, N * 4 +1]] + name = value.name if isinstance(value, tir.Var) else prefer_name + return Value(kind="dynamic", name=name, dtype=value.dtype, value=value) + elif value is Any or value is None or value is dt.dtype or isinstance(value, (type,) + _GenericAliasTypes): + # A # no annotation + # A: Any + # A: _T + # A: dt.dtype + # A: tuple[...] + return Value(kind="static", name=prefer_name, value=None) + else: + raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}") + + def with_name(self, name: str) -> Value: + return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value) + + def get_key_parser(self): + if self.kind == "static": + if self.value is not None: + expected_value = self.value + + def key_parser(name: str, target: Any): + assert target == expected_value + return target + + return key_parser + else: + return lambda name, target: (target,) + else: + return lambda name, target: (None,) + + def parse_key(self, target: Any): + return self.get_key_parser()(target) + + def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable, create_arg: bool = True): + if self.kind == "static": + if self.value: + assert self.value == value, f"static value mismatch for {name}: expected {self.value}, got {value}" + return value + else: + name = self.name or name or vt.create_tmp_name() + if self.value is not None: + arg = self.value + elif self.creator is not None: + arg = self.creator() + else: + arg = vt.get_or_create_var(name, self.dtype) + return tb_tir.arg(name, arg) if create_arg else arg + + def __repr__(self): + if self.kind == "static": + if self.value is not None: + return repr(self.value) + else: + return (str(self.name) or "$unnamed") + "$" + else: + if self.value is not None: + return repr(self.value) + elif self.creator is not None: + return repr(self.creator()) + else: + return (str(self.name) or "$unnamed") + "$dyn" + + +def _canonicalize_dtype(val: Any) -> dt.dtype | None: + if val == Any or val is None: + return None + if isinstance(val, TypeVar): + return None + return dt.dtype(val) + + +def _canonicalize_shape(shape: Sequence[Any]) -> list[Value]: + if shape is None or shape is Any: + return None + return [Value.from_value(dim) for _, dim in enumerate(iterable=shape)] + + +def _canonicalize_strides(strides: Sequence[Any]) -> list[Value]: + if strides is None or strides is Any: + return None + return [Value.from_value(dim) for _, dim in enumerate(strides)] + + +def _shape_with_name(shape: Sequence[Value], base_name: str) -> list[Value]: + if shape is None: + return None + res = [] + for i, dim in enumerate(shape): + dim = dim.with_name(f"{base_name}_{i}") + res.append(dim) + return res + + +def _try_convert_static_shape(shape: Sequence[Value]): + if shape is None: + return None + res = [] + for s in shape: + if s.kind == "static" and s.value is not None or s.kind == "dynamic" and s.value is not None: + 
res.append(s.value) + if len(res) == len(shape): + return res + + +@dataclass +class BufferAnnot(Annot): + shape: tuple = None + strides: tuple = None + dtype: dt.dtype = None + + def is_kernel_arg(self) -> bool: + return True + + @property + def scope(self): + return "global" + + def __call__( + self, + shape: tuple[Unpack[_Shapes]], + dtype: _DType = "float32", + data=None, + strides=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: + return buffer( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope or self.scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, params): + shape, dtype = params + if not isinstance(shape, (tuple, list)): + shape = (shape,) + shape = _canonicalize_shape(shape) + dtype = _canonicalize_dtype(dtype) + return self.__class__(shape, strides=self.strides, dtype=dtype) + + def with_name(self, name: str): + shape = _shape_with_name(self.shape, base_name=f"{name}_shape") + strides = _shape_with_name(self.strides, base_name=f"{name}_stride") + return self.__class__(shape, strides, self.dtype) + + def get_key_parser(self): + raw_shapes = True + if self.shape is not None: + raw_shapes = False + shape_len = len(self.shape) + static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == "static"] + # static_fixed_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static' and dim.value is not None] + # static_fixed_shape_values = [dim.value for dim in self.shape if dim.kind == 'static' and dim.value is not None] + raw_strides = True + if self.strides is not None: + raw_strides = False + strides_len = len(self.strides) + strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == "static"] + # static_fixed_strides_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static' and dim.value is not None] + # static_fixed_strides_values = [dim.value for dim in self.strides if dim.kind == 'static' and dim.value is not None] + raw_dtype = True + if self.dtype is not None: + raw_dtype = False + expected_dtype = self.dtype + + def key_parser(name: str, target: Any): + if isinstance(target, torch.Tensor): + shape = tuple(target.shape) + strides = tuple(target.stride()) + dtype = dt.dtype(target.dtype) + elif isinstance(target, tir.Buffer): + shape = tuple(target.shape) + strides = tuple(target.strides) + dtype = dt.dtype(target.dtype) + else: + raise TypeError( + f"Unsupported buffer argument type for argument `{name}`: expected a `torch.Tensor` or `tir.Buffer`, got {type(target)}" + ) + if not raw_shapes: + assert len(shape) == shape_len + shape = tuple(shape[i] for i in static_shape_idx) + # shape_fixed = tuple(shape[i] for i in static_fixed_shape_idx) + # assert shape_fixed == static_fixed_shape_values, f"shape mismatch" + if not raw_strides: + assert len(strides) == strides_len + strides = tuple(strides[i] for i in strides_shape_idx) + # strides_fixed = tuple(strides[i] for i in static_fixed_strides_idx) + # assert strides_fixed == static_fixed_strides_values + if not raw_dtype: + dtype = dt.dtype(dtype) + if dtype != expected_dtype: + raise TypeError(f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}") + return shape, strides, dtype + + return key_parser + + def parse_key(self, target: Any): + return self.get_key_parser()(target) + + @staticmethod 
+ def match_shape(shape: tuple[Value, ...], target_shape: tuple[int, ...], vt: ArgVarTable): + if shape is None: + return target_shape + args = [] + for s, target in zip(shape, target_shape): + args.append(s.create_prim_func_arg(s.name, target, vt, create_arg=False)) + return args + + def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable): + if isinstance(value, tir.Buffer): + shape = value.shape + strides = value.strides + dtype = value.dtype + elif isinstance(value, torch.Tensor): + shape = value.shape + strides = value.stride() + dtype = dt.dtype(value.dtype) + else: + raise TypeError(f"Unsupported buffer argument type: {type(value)}") + shape = self.match_shape(self.shape, shape, vt) + strides = self.match_shape(self.strides, strides, vt) + arg = buffer(shape, dtype=self.dtype or dtype, strides=strides, scope=self.scope) + return tb_tir.arg(name, arg) + + def promote(self): + shape = _try_convert_static_shape(self.shape) + strides = _try_convert_static_shape(self.strides) + if shape is not None and strides is not None and self.dtype is not None: + buf = buffer(shape, self.dtype, strides=strides, scope=self.scope) + return TIRAnnot(data=buf) + + +class TensorAnnot(BufferAnnot): + @staticmethod + def _construct_strides(shape: tuple[Any]): + s, strides = 1, [1] + for dim in shape[:0:-1]: + s *= dim + strides.append(s) + return tuple(reversed(strides)) + + def __call__( + self, + shape: tuple[Unpack[_Shapes]], + dtype: _DType = "float32", + data=None, + strides=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ): + if isinstance(shape, (int, PrimExpr)): + shape = (shape,) + strides = strides or self._construct_strides(shape) + return super().__call__( + shape=shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def promote(self): + shape = _try_convert_static_shape(self.shape) + if shape is not None and self.dtype is not None: + strides = self._construct_strides(shape) + buf = buffer(shape, self.dtype, strides=strides, scope=self.scope) + return TIRAnnot(data=buf) + + +class StridedTensorAnnot(BufferAnnot): + def __call__( + self, + shape, + strides, + dtype: _DType = "float32", + data=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ): + return super().__call__( + shape=shape, + strides=strides, + dtype=dtype, + data=data, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, params): + shape, strides, dtype = params + shape = _canonicalize_shape(shape) + strides = _canonicalize_strides(strides) + dtype = _canonicalize_dtype(dtype) + return StridedTensorAnnot(shape, strides, dtype) + + +class FragmentBufferAnnot(BufferAnnot): + @property + def scope(self): + return "local.fragment" + + +class SharedBufferAnnot(BufferAnnot): + @property + def scope(self): + return "shared.dyn" + + +class LocalBufferAnnot(BufferAnnot): + @property + def scope(self): + return "local" + + +class DynAnnot(Value): + """ + Dynamic variable annotation represents a tvm tir.Var argument + """ + + def __call__(self, dtype: AnyDType = dt.float32, name: str | None = None) -> DynAnnot: + return tir.Var(name, dtype) + + def __getitem__(self, params): + if not isinstance(params, tuple): + params = 
(params,) + dtype = None + if len(params) == 1: + (name,) = params + if len(params) == 2: + dtype, name = params + dtype = _canonicalize_dtype(dtype) or dt.int32 + return DynAnnot(kind="dynamic", dtype=dtype, name=name) + + +@dataclass +class DTypeAnnot(Annot): + """ + Data type annotation ensures automatically conversion from AnyDType to dtype + >>> def foo(A: T.dtype): print(A) + >>> foo(torch.float32) + dtype('float32') + >>> foo(T.float32) + dtype('float32') + >>> foo('float32') + dtype('float32') + """ + + name: str | None = None + + def is_kernel_arg(self) -> bool: + return False + + def with_name(self, name): + return DTypeAnnot(name=name) + + def get_key_parser(self): + return lambda name, value: (dt.dtype(value),) + + def create_prim_func_arg(self, name, value, vt): + return dt.dtype(value) + + def __repr__(self): + return self.name + "$dtype" + + +@dataclass +class TIRAnnot(Annot): + """ + TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments + >>> def foo(A: T.Buffer((128,), T.float32)): ... + """ + + data: tir.Buffer | tir.Var + + def is_kernel_arg(self) -> bool: + return True + + def get_key_parser(self): + return lambda name, value: (None,) + + def create_prim_func_arg(self, name, value, vt): + return tb_tir.arg(name, self.data) + + def with_name(self, name: str): + IRBuilder.name(name, self.data) + return self + + def __repr__(self): + return repr(self.data) + + +if TYPE_CHECKING: + + class Buffer(Generic[_Shape, _DType]): + def __init__( + shape: tuple[Unpack[_Shapes]], + dtype: _DType = "float32", + data=None, + strides=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: ... + + @property + def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: ... + + @property + def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: ... + + @property + def strides(self) -> tuple[tir.PrimExpr]: ... + + def scope(self) -> Scope: ... + + class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]): + def __new__( + shape: tuple[Unpack[_Shapes]], + dtype: _DType = "float32", + data=None, + strides=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... + + class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]): + def __new__( + shape: tuple[Unpack[_Shapes]], + strides=None, + dtype: _DType = "float32", + data=None, + elem_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... + + class FragmentBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]): + pass + + class LocalBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]): + pass + + class SharedBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]): + pass + + class dyn(tir.Var): + def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: ... + + @property + def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: ... 
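+
+# Hedged usage sketch: the Generic stubs above exist only for static type checkers,
+# while the singletons assigned in the `else` branch below are what actually runs.
+# A kernel signature written against them looks roughly like this (the usual `T`
+# alias for this v2 module is assumed):
+#
+#   @T.prim_func
+#   def kernel(A: T.Tensor[(1024, 1024), "float16"], n: T.dyn):
+#       ...
+#
+# Literal dims let the Tensor annotation promote to a fixed TIRAnnot buffer, while
+# `T.dyn` stays a dynamic tir.Var kernel argument (int32 by default).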
+ +else: + Buffer = BufferAnnot() + Tensor = TensorAnnot() + StridedTensor = StridedTensorAnnot() + FragmentBuffer = FragmentBufferAnnot() + SharedBuffer = SharedBufferAnnot() + LocalBuffer = LocalBufferAnnot() + dyn = DynAnnot() + + +@dataclass +class FuncAnnot: + sig: inspect.Signature + arg_names: list[str] + annots: dict[str, Annot] + arg_parser: dict[str, Callable[[Any], tuple[Any, ...]]] + ker_arg_names: list[str] + + @classmethod + def from_sig_annots(cls, sig: inspect.Signature, func_annots: dict[str, Any]) -> FuncAnnot: + annots = {} + arg_parser = {} + ker_arg_names = [] + for param in sig.parameters.values(): + name = param.name + annot = func_annots.get(name, Value("static", name)) + if not isinstance(annot, Annot): + if not isinstance(annot, type) and callable(annot): + annot = annot() + if annot is dt.dtype: + annot = DTypeAnnot(name=name) + elif isinstance(annot, (tir.Buffer, tir.Var)): + annot = TIRAnnot(data=annot) + else: + annot = Value(kind="static", name=name) + annot = annot.promote() or annot + annots[name] = annot.with_name(name) + if annot.is_kernel_arg(): + ker_arg_names.append(name) + arg_parser[name] = annot.get_key_parser() + arg_names = list(sig.parameters.keys()) + return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names) + + def parse_key(self, *args, **kws): + """ + Parse arguments and generates the cache key for jit caching + """ + args = {name: arg for name, arg in zip(self.arg_names, args)} + arg_dict = dict(**args, **kws) + parsed = [] + for name, value in arg_dict.items(): + key = self.arg_parser[name](name, value) + parsed.append((name, key)) + return tuple(sorted(parsed)) + + def convert_to_kernel_args(self, *args, **kws): + args = {name: arg for name, arg in zip(self.arg_names, args)} + arg_dict = dict(**args, **kws) + return [arg_dict[name] for name in self.ker_arg_names] + + def create_argument(self, name: str, value: Any, vt: ArgVarTable): + """ + Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation + """ + return self.annots[name].create_prim_func_arg(name, value, vt) + + def is_all_static(self): + """ + Check if all arguments are static (i.e., can be fully determined at compile time) + """ + return all(isinstance(annot, TIRAnnot) for annot in self.annots.values()) + + def get_all_static_args(self): + res = {} + for name, annot in self.annots.items(): + if isinstance(annot, TIRAnnot): + res[name] = annot.data + return res + + def get_compile_time_unknown_args(self): + res = [] + for name, annot in self.annots.items(): + if not isinstance(annot, TIRAnnot): + res.append(name) + return res diff --git a/tilelang/original/tilelang/language/v2/ast.py b/tilelang/original/tilelang/language/v2/ast.py new file mode 100644 index 0000000000000000000000000000000000000000..26c1851ebcee147c8a4ad841e004964f9511b9b3 --- /dev/null +++ b/tilelang/original/tilelang/language/v2/ast.py @@ -0,0 +1,595 @@ +from __future__ import annotations +import ast +from dataclasses import dataclass +from typing import Callable, Generic, Any, Literal, TypeVar +from contextlib import AbstractContextManager +from collections.abc import Iterable + +# Python 3.9 compatibility for ParamSpec +try: + from typing import ParamSpec +except ImportError: # Python < 3.10 + from typing_extensions import ParamSpec +import inspect + +# from .utils import get_ast, get_compiled_object +from . 
import utils + +_span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + +def ast_has_span(ast: ast.AST) -> bool: + return all(hasattr(ast, attr) for attr in _span_attrs) + + +def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]: + if not ast_has_span(ast): + return None + return tuple(getattr(ast, attr) for attr in _span_attrs) + + +def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]): + if not ast_has_span(ast): + return + for attr, value in zip(_span_attrs, span): + setattr(ast, attr, value) + + +class QuoteVisitor(ast.NodeTransformer): + def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None): + self.names = names + self.passes = passes or [] + self.span = span + + def generic_visit(self, node: ast.AST): + if self.span is not None: + ast_set_span(node, self.span) + return super().generic_visit(node) + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in self.names: + return self.names[node.id] + else: + return node + + def visit_Pass(self, node: ast.Pass) -> Any: + item = self.passes.pop(0) + return item if item else node + + +def quote(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> list[ast.AST]: + tree = ast.parse(expr) + if isinstance(span, ast.AST): + span = ast_get_span(span) + tree = QuoteVisitor(kws, passes, span).visit(tree) + return tree.body + + +def quote1(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> ast.AST: + res = quote(expr, passes=passes, span=span, **kws) + assert len(res) == 1 + return res[0] + + +def quote_expr(expr: str, **kws) -> ast.expr: + res = quote1(expr, **kws) + assert isinstance(res, ast.Expr) + return res.value + + +Operator = Literal["Add", "Sub", "Mult", "MatMult", "Div", "Mod", "Pow", "LShift", "RShift", "BitOr", "BitXor", "BitAnd", "FloorDiv"] +BoolOp = Literal["And", "Or", "Not"] + + +def get_operator_name(operator: ast.operator) -> Operator: + return operator.__class__.__name__ + + +def get_boolop_name(boolop: ast.boolop) -> BoolOp: + return boolop.__class__.__name__ + + +_T = TypeVar("_T") + + +def eval_op(op: Operator, left: Any, right: Any) -> Any: + if op == "Add": + return left + right + if op == "Sub": + return left - right + if op == "Mult": + return left * right + if op == "MatMult": + return left @ right + if op == "Div": + return left / right + if op == "Mod": + return left % right + if op == "Pow": + return left**right + if op == "LShift": + return left << right + if op == "RShift": + return left >> right + if op == "BitOr": + return left | right + if op == "BitXor": + return left ^ right + if op == "BitAnd": + return left & right + if op == "FloorDiv": + return left // right + raise ValueError(f"Unknown operator: {op}") + + +def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any: + if op == "Add": + left[sl] += right + return left + if op == "Sub": + left[sl] -= right + return left + if op == "Mult": + left[sl] *= right + return left + if op == "MatMult": + left[sl] @= right + return left + if op == "Div": + left[sl] /= right + return left + if op == "Mod": + left[sl] %= right + return left + if op == "Pow": + left[sl] **= right + return left + if op == "LShift": + left[sl] <<= right + return left + if op == "RShift": + left[sl] >>= right + return left + if op == "BitOr": + left[sl] |= right + return left + if op == "BitXor": + left[sl] ^= right + return left + if op == "BitAnd": + left[sl] &= right + return left + if op == "FloorDiv": + left[sl] //= right + return left + raise 
ValueError(f"Unknown operator: {op}") + + +class _empty: ... + + +class BaseBuilder: + empty = _empty + + def get_parent_locals(self): + return inspect.currentframe().f_back.f_back.f_locals + + def ctx_if(self, cond) -> Iterable[_T]: + yield cond + + def ctx_then(self, val: _T) -> Iterable[None]: + if val: + yield + + def ctx_else(self, val: _T) -> Iterable[None]: + if not val: + yield + + def eval(self, val: Any): # noqa: B027 + pass + + def ctx_for(self, range: Iterable[Any]) -> Iterable[Any]: + return range + + def ctx_continue(self) -> bool: + return True + + def ctx_break(self) -> bool: + return True + + def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]: + while cond(): + yield + + def bind(self, name: str, value: Any, annot: Any = empty) -> Any: + return value + + def unwrap_value(self, value): + return value + + def assign_slice(self, lval: Any, sl: slice, value: Any, annot: Any = empty): + lval[sl] = value + + def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any: + return eval_op(op, target, aug_value) + + def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any): + eval_aug_assign(op, target, sl, aug_value) + + def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any: + if op == "And": + return left and right() + if op == "Or": + return left or right() + if op == "Not": + return not left + raise ValueError(f"Unknown boolop: {op}") + + def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: + return then() if cond else otherwise() + + def ret(self, value: Any) -> Any: + return value + + def ctx_with(self, ctx: AbstractContextManager[Any]) -> AbstractContextManager[Any]: + return ctx + + def assert_expr(self, cond: Any, msg: Any): + assert cond, msg + + def rval(self, name: str, value: Any): + return value + + def arg(self, name: str, value: Any): + return value + + def override(self, name: str): + return globals()[name] + + +class DSLMutator(ast.NodeTransformer): + def __init__(self, closure_names: list[str]): + self.tmp_counter = 0 + self.closure_names = closure_names + + def get_tmp(self) -> str: + name = f"__{self.tmp_counter}" + self.tmp_counter += 1 + return name + + def visit_If(self, node: ast.If): + node = self.generic_visit(node) + br = self.get_tmp() + if len(node.orelse) == 0: + return quote( + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n", + cond=node.test, + passes=[node.body], + span=node, + ) + return quote( + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n for _ in __tb.ctx_else({br}):\n pass\n", + cond=node.test, + passes=[node.body, node.orelse], + span=node, + ) + + def visit_Expr(self, node: ast.Expr): + node = self.generic_visit(node) + return quote("__tb.eval(value)", value=node.value, span=node) + + def _parse_names(self, target: ast.expr): + if isinstance(target, ast.Name): + return f"'{target.id}'" + elif isinstance(target, ast.Tuple): + return "(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)" + else: + s = ast.unparse(target) + raise NotImplementedError(f"Unsupported for target `{s}`") + + def visit_For(self, node: ast.For): + node = self.generic_visit(node) + tmp = self.get_tmp() + # names = self._parse_names(node.target) + var = ast.Name(tmp, ctx=ast.Load()) + ast_set_span(var, ast_get_span(node.target)) + stmts = self._emit_assign_target(node.target, var) + return quote( + f"for {tmp} in __tb.ctx_for(range):\n pass\n", + target=node.target, + range=node.iter, + 
passes=[stmts + node.body], + span=node, + ) + + def visit_Continue(self, node: ast.Continue): + node = self.generic_visit(node) + return quote("if __tb.ctx_continue(): continue", span=node) + + def visit_Break(self, node: ast.Break): + node = self.generic_visit(node) + return quote("if __tb.ctx_break(): break", span=node) + + def _emit_assign_target(self, target: ast.expr, rval: ast.expr, annot: ast.expr = None) -> list[ast.AST]: + if isinstance(target, ast.Name): + if annot is None: + return quote(f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + else: + return quote(f'name = __tb.bind("{target.id}", value, annot)', name=target, value=rval, annot=annot, span=target) + elif isinstance(target, ast.Attribute): + s = ast.unparse(target) + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") + elif isinstance(target, ast.Subscript): + if annot is None: + return quote( + "__tb.assign_slice(lval, slice, value)", + lval=target.value, + slice=target.slice, + value=rval, + span=target, + ) + else: + return quote( + "__tb.assign_slice(lval, slice, value, annot)", + lval=target.value, + slice=target.slice, + value=rval, + annot=annot, + span=target, + ) + else: + # flatten nested tuple into a list of (tmp_name, target) + unpacked = [] + + def _visit_target(target: ast.expr) -> str: + if isinstance(target, (ast.Name, ast.Subscript)): + tmp = self.get_tmp() + unpacked.append((tmp, target)) + res = ast.Name(id=tmp, ctx=target.ctx) + ast_set_span(res, ast_get_span(target)) + return res + elif isinstance(target, ast.Tuple): + elts = [_visit_target(elt) for elt in target.elts] + res = ast.Tuple(elts=elts, ctx=target.ctx) + ast_set_span(res, ast_get_span(target)) + return res + else: + s = ast.unparse(target) + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") + + unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=quote_expr("__tb.unwrap_value(rval)", rval=rval, span=rval)) + ast_set_span(unpack_stmt, ast_get_span(target)) + stmts = [unpack_stmt] + bind_lvals = [] + bind_rvals = [] + + def flush_binds(): + if bind_lvals: + stmts.append(quote1(f"{', '.join(bind_lvals)}, = {', '.join(bind_rvals)},", span=target)) + bind_lvals.clear() + bind_rvals.clear() + + # the following code generate two phase binding to support swap like semantics + # for example: + # a, b = b, a + # 1 phase: + # _tmp_0, _tmp_1 = b, a + # => _tmp_0: T.int32 = b + # => _tmp_1: T.int32 = a + # 2 phase: + # a, b = _tmp_0, _tmp_1 + # => a = _tmp_0 => a[0] = _tmp_0 + # => b = _tmp_1 => b[0] = _tmp_1 + + # 1 phase: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b) + for tmp, _target in unpacked: + bind_lvals.append(tmp) + bind_rvals.append(f'__tb.bind("_", {tmp})') + + flush_binds() + + # 2 phase: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1) + for tmp, target in unpacked: + if isinstance(target, ast.Name): + bind_lvals.append(target.id) + bind_rvals.append(f'__tb.bind("{target.id}", {tmp})') + elif isinstance(target, ast.Subscript): + flush_binds() + stmts.append(quote1(f"__tb.assign_slice(lval, slice, {tmp})", lval=target.value, slice=target.slice, span=target)) + else: + s = ast.unparse(target) + raise NotImplementedError(f"Unsupported target: {s}") + flush_binds() + return stmts + + def visit_Assign(self, node: ast.Assign) -> list[ast.AST]: + node = self.generic_visit(node) + rval = node.value + if len(node.targets) == 1: + return self._emit_assign_target(node.targets[0], rval) + else: + tmp_name = self.get_tmp() + tmp_store = 
ast.Name(tmp_name, ctx=ast.Store()) + tmp_load = ast.Name(tmp_name, ctx=ast.Load()) + ast_set_span(tmp_store, node.targets[0]) + ast_set_span(tmp_load, node.targets[0]) + stmt = self._emit_assign_target(tmp_store, rval) + for target in node.targets: + stmt.extend(self._emit_assign_target(target, tmp_load)) + return stmt + + def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]: + node = self.generic_visit(node) + target, rval = node.target, node.value + op = get_operator_name(node.op) + if isinstance(target, ast.Name): + return quote(f"name = __tb.aug_assign('{op}', {target.id}, value)", name=target, value=rval, span=node) + elif isinstance(target, ast.Subscript): + return quote( + f"__tb.aug_assign_slice('{op}', lval, slice, value)", + lval=target.value, + slice=target.slice, + value=rval, + span=node, + ) + else: + return node + + def visit_AnnAssign(self, node: ast.AnnAssign): + node = self.generic_visit(node) + rval = node.value or quote_expr("__tb.empty", span=node, annot=node) + return self._emit_assign_target(node.target, rval, annot=node.annotation) + + def visit_While(self, node): + node = self.generic_visit(node) + return quote1("for _ in __tb.ctx_while(lambda: cond):\n pass", cond=node.test, passes=[node.body], span=node) + + def visit_FunctionDef(self, node: ast.FunctionDef): + node = self.generic_visit(node) + all_args = node.args.posonlyargs + node.args.args + if node.args.vararg is not None: + all_args += node.args.vararg + all_args += node.args.kwonlyargs + stmts = [] + for arg in all_args: + name = arg.arg + if arg.annotation is not None: + arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg) + else: + arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg) + arg.annotation = None + stmts.append(arg_stmt) + node.body = stmts + node.body + node.decorator_list.clear() + return quote1( + f"def make_closure({', '.join(self.closure_names)}):\n" + f" def {node.name}(__tb):\n" + " range = __tb.override('range')\n" + " pass\n" + f" return {node.name}\n" + f" return {node.name}", + passes=[node], + ) + + def visit_BoolOp(self, node: ast.BoolOp): + node = self.generic_visit(node) + op_name = get_boolop_name(node.op) + last = node.values[-1] + for i in reversed(range(len(node.values) - 1)): + last = quote_expr( + expr=f"__tb.boolop('{op_name}', left, lambda: right)", + left=node.values[i], + right=last, + span=node, + ) + return last + + def visit_UnaryOp(self, node: ast.UnaryOp): + node = self.generic_visit(node) + if isinstance(node.op, ast.Not): + return quote_expr("__tb.boolop('Not', operand)", operand=node.operand, span=node) + return node + + def visit_Compare(self, node: ast.Compare) -> ast.expr: + node = self.generic_visit(node) + left = node.left + split = [] + for op, comp in zip(node.ops, node.comparators): + cmp = ast.Compare(left=left, ops=[op], comparators=[comp]) + ast_set_span(cmp, ast_get_span(node)) + split.append(cmp) + left = comp + last = split[-1] + for i in reversed(range(len(split) - 1)): + last = quote_expr("__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) + return last + + def visit_IfExp(self, node: ast.IfExp) -> ast.Expr: + node = self.generic_visit(node) + return quote_expr( + "__tb.ifexp(cond, lambda: then, lambda: otherwise)", cond=node.test, then=node.body, otherwise=node.orelse, span=node + ) + + def visit_Return(self, node: ast.Return): + node = self.generic_visit(node) + return quote("return __tb.ret(value)", value=node.value, span=node) + + def visit_With(self, node: ast.With): + 
node = self.generic_visit(node) + for expr in node.items: + expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr) + return node + + def visit_Assert(self, node: ast.Assert): + node = self.generic_visit(node) + return quote("__tb.assert_expr(cond, msg)", cond=node.test, msg=node.msg, span=node) + + def visit_Name(self, node: ast.Name): + if isinstance(node.ctx, ast.Load): + return quote_expr(f"__tb.rval('{node.id}', node)", node=node, span=node) + return node + + +_P = ParamSpec("_P") + + +@dataclass +class IRGenerator(Generic[_P, _T]): + gen: Callable[[BaseBuilder], Callable[_P, _T]] + source: str + + +def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: + """ + Transform a Python function into an IR (Intermediate Representation) generator. + This function takes a regular Python function and performs AST (Abstract Syntax Tree) + transformation to create an IRGenerator that can be used for code generation purposes. + Args: + func (Callable[_P, _T]): The Python function to be transformed. This should be a + callable that will be analyzed and mutated at the AST level. The function's + signature is preserved through generic type parameters _P (parameters) and + _T (return type). + Returns: + IRGenerator[_P, _T]: An IRGenerator instance wrapping the transformed function. + The generator contains: + - gen: The compiled and mutated version of the original function + - source: The unparsed source code of the transformed AST as a string + Example: + >>> @mutate + ... def my_function(x: int) -> int: + ... return x * 2 + >>> # my_function is now an IRGenerator that can be used for code generation + Note: + - The original function's closure variables and captured context are preserved + - The transformation is performed at compile-time through AST manipulation + - The returned IRGenerator maintains type information from the original function + """ + + tree = utils.get_ast(func) + filename = inspect.getsourcefile(func) or inspect.getfile(func) + nonlocals = utils.get_func_nonlocals(func) + + # DSLMutator generates a function named `make_closure` + # it accepts all names inside nonlocal, and returns the mutated function + # this is because we must separate the closure namespace form the global namespace + # if we directly inject closure variables into the global namespace, + # it generates a new `globals` dict, and the dict owns all reference to the original globalns + # which makes memory leak, because the original globalns cannot be freed + # ```py + # a = 123 + # def foo(): + # x = foo.__globals__ # OK, globals are maintained by python + # x = {**foo.__globals__, } # Not OK: globals are copied, and the original globals cannot be freed + # def bar(): x + # return bar + # ``` + tree = DSLMutator(nonlocals.keys()).visit(tree) + + make_closure = utils.get_compiled_object( + tree, + "make_closure", + filename, + func.__globals__, # use the original globalns + ) + fn = make_closure(**nonlocals) + return IRGenerator(gen=fn, source=ast.unparse(tree)) diff --git a/tilelang/original/tilelang/language/v2/builder.py b/tilelang/original/tilelang/language/v2/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8930a1b34f1e2b421576056f532e9b270ddcec --- /dev/null +++ b/tilelang/original/tilelang/language/v2/builder.py @@ -0,0 +1,884 @@ +from __future__ import annotations +from contextlib import contextmanager, AbstractContextManager +from dataclasses import dataclass +import inspect + +from tilelang.language.kernel import KernelLaunchFrame +from 
tvm_ffi.container import Map +from tvm.ir.base import Span +from tvm.ir.expr import Range +from tvm.tir.stmt import BufferRegion +from .ast import BaseBuilder, IRGenerator, eval_op, mutate +from .utils import construct_strides +import tvm +from tvm.tir import Buffer +from tvm.script.ir_builder import tir, IRBuilder + +from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var +from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union +from collections.abc import Sequence +from .annot import FuncAnnot, ArgVarTable, Annot +import pprint + +# Python 3.9 compatibility for ParamSpec and Self +try: + from typing import ParamSpec, Self +except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec + from typing_extensions import ParamSpec, Self +from . import dtypes as dt +from . import utils +import threading +import logging + +logger = logging.getLogger(__name__) + + +def unwrap_expr(expr) -> PrimExpr | int | float: + """ + unwrap expr and convert it into PrimExpr like + """ + if isinstance(expr, tir.meta_var): + expr = expr.value + elif isinstance(expr, Ref): + return expr.load() + elif is_var(expr): + expr = tir.BufferLoad(expr, indices=[0]) + elif isinstance(expr, (EqualOp, NotEqualOp)): + expr = expr.asobject() + return expr + + +def unwrap_cond(expr): + """ + unwrap expr and convert to bool condition + """ + expr = unwrap_expr(expr) + if isinstance(expr, (IntImm, FloatImm, StringImm)): + return bool(expr.value) + elif isinstance(expr, PrimExpr): + return expr + elif isinstance(expr, Buffer): + raise TypeError(f"Buffer `{expr}` cannot be used as condition directly.") + elif isinstance(expr, (int, bool)) or expr is None: + return bool(expr) + else: + logger.warning( + f"Python expression `{expr}` is used as condition in TileLang, \nthis is treated as a constant expression. ", + stack_info=True, + stacklevel=3, + ) + return bool(expr) + + +thread_local_storage = threading.local() + + +class Frame: + """ + Frame are virtual context managers used in frontend only + They do not have any runtime representation in the generated TIR. + """ + + def __enter__(self): ... + + def __exit__(self, exc_type, exc_value, traceback): ... + + +class MacroFrame(Frame): ... + + +class ExitedMacroFrame(Frame): ... + + +class BoolOpFrame(Frame): ... + + +class ConstIfFrame(Frame): ... + + +class BlockFrame(Frame): ... + + +class ContinueFrame(Frame): ... + + +class BreakFrame(Frame): ... + + +@dataclass +class SerialForWithStep: + start: PrimExpr + stop: PrimExpr + step: PrimExpr + annotations: dict[str, Any] | None = None + + +@dataclass +class OutTensor: + shape: Sequence[PrimExpr] + dtype: dt.dtype + + @property + def strides(self): + return construct_strides(tuple(self.shape)) + + +@dataclass +class Ref: + bufload: BufferLoad + + @property + def buffer(self): + return self.bufload.buffer + + def store(self, value): + tir.buffer_store(self.bufload.buffer, value, self.bufload.indices) + + def load(self): + return self.bufload + + +class UnrollForWithStep(SerialForWithStep): ... 
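+# Illustrative sketch (added commentary, not part of the original source): SerialForWithStep
+# describes a stepped loop that `Builder.ctx_for` below lowers to a canonical unit-step loop.
+# Assuming a frontend call such as `T.serial(0, 10, 2)` produces SerialForWithStep(start=0,
+# stop=10, step=2), ctx_for emits a loop of trip count ceildiv(stop - start, step) = 5 and
+# rebinds the user-visible index as start + v * step, yielding 0, 2, 4, 6, 8.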
+ + +# Python 3.9 compatibility: avoid PEP 604 unions at runtime +# Use tuple for isinstance checks and typing.Union for annotations/aliases +ContinueOrBreak = (ContinueFrame, BreakFrame) +AnyFrame = Union[tir.frame.IRBuilderFrame, Frame] + +TIR_CONTROL_FRAME = ( + tir.frame.WhileFrame, + tir.frame.ForFrame, + tir.frame.IfFrame, + tir.frame.PrimFuncFrame, +) + +TIR_VAR_SCOPE_FRAME = ( + tir.frame.WhileFrame, + tir.frame.ForFrame, + tir.frame.IfFrame, + tir.frame.PrimFuncFrame, + MacroFrame, + KernelLaunchFrame, +) + + +def is_var(v: Any) -> bool: + return isinstance(v, Buffer) and v.scope() == "local.var" + + +class Builder(BaseBuilder): + def __init__(self, func_annot: FuncAnnot = None): + self.frames: list[AnyFrame] = [] + self.ir_builder = IRBuilder() + self.name_inside_frame: dict[str, AnyFrame] = {} + self.macro_arg_annot = {} + self.func_annot = func_annot + self.out_idx = [] + self.out_tensor_cnt = 0 + self.arg_vt = ArgVarTable() + + @classmethod + def current(cls) -> Self: + builder = getattr(thread_local_storage, "builder", None) + return builder + + @contextmanager + def prim_func(self, name): + thread_local_storage.builder = self + with self.ir_builder, self.with_frame(tir.prim_func()): + tir.func_name(name) + yield + if len(self.out_idx) != self.out_tensor_cnt: + raise RuntimeError("Not all tensor allocated from `T.empty` are returned") + + @contextmanager + def macro(self, name=None, annotations=None): + if self.find_frame_idx(BoolOpFrame) is not None: + raise RuntimeError( + f"Macro `{name}` is used inside boolean expressions, " + "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs" + ) + save = self.name_inside_frame, self.macro_arg_annot + self.name_inside_frame = {} + self.macro_arg_annot = annotations or {} + pos = len(self.frames) + # here we add a ExitedMacroFrame to preserve the frame stack inside macro + # because macro may bind some variable, and return it + # + # ```py + # @T.macro + # def foo(x): + # y = x + 1 + # return y + # @T.prim_func + # def bar(): + # c = foo(1) # macro generates let y = x + 1 + # d = c # d = c should lay inside frame of `let y = x + 1` + self.frames.append(MacroFrame()) + yield + self.frames[pos] = ExitedMacroFrame() + self.name_inside_frame, self.macro_arg_annot = save + + def get(self): + return self.ir_builder.get() + + def find_frame_idx(self, frame: type | tuple[type, ...], start=0) -> int | None: + for idx in reversed(range(start, len(self.frames))): + f = self.frames[idx] + if isinstance(f, frame): + return idx + + def enter_frame(self, frame: AbstractContextManager[Any]): + self.frames.append(frame) + return frame.__enter__() + + def check_continue_break(self): + idx = self.find_frame_idx(ContinueOrBreak) + if idx is not None: + logger.warning("Writing code after continue/break may cause undefined behavior in tilelang.", stack_info=True, stacklevel=3) + + @contextmanager + def with_frame(self, frame: AbstractContextManager[Any] | None): + pop_idx = len(self.frames) + yield self.enter_frame(frame) + while len(self.frames) > pop_idx: + self.frames.pop().__exit__(None, None, None) + + class _has_if_frame: ... 
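+    # Added explanatory note (not in the original source): `ctx_if` yields the sentinel
+    # `_has_if_frame` when the condition unwraps to a PrimExpr, so the generated
+    # `ctx_then`/`ctx_else` bodies open real TIR Then/Else frames; when the condition is a
+    # plain Python constant, the branch is decided at trace time and only the taken body is
+    # emitted. For example (names are hypothetical):
+    #     if n_blocks > 1:    # PrimExpr condition -> lowered to a TIR if/then/else
+    #         ...
+    #     if enable_debug:    # Python bool captured at trace time -> folded away
+    #         ...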
+ + def ctx_if(self, cond): + self.check_continue_break() + cond = unwrap_cond(cond) + if isinstance(cond, PrimExpr): + with self.with_frame(tir.If(cond)): + yield self._has_if_frame + else: + with self.with_frame(ConstIfFrame()): + yield cond + + def ctx_then(self, val): + if val is self._has_if_frame: + with self.with_frame(tir.Then()): + yield + else: + with self.with_frame(BlockFrame()): + if val: + yield + + def ctx_else(self, val): + if val is self._has_if_frame: + with self.with_frame(tir.Else()): + yield + else: + with self.with_frame(BlockFrame()): + if not val: + yield + + def eval(self, val: Any): + val = unwrap_expr(val) + if val is None: + pass + elif isinstance(val, tir.frame.IRBuilderFrame): + if isinstance(val, tir.frame.ForFrame): + logger.warning( + "Evaluating a for frame may cause undefined behavior in tilelang.", + stack_info=True, + stacklevel=1, + ) + self.enter_frame(val) + elif isinstance(val, PrimExpr): + tir.evaluate(val) + elif isinstance(val, (int, bool)): + tir.evaluate(tvm.tir.const(val)) + elif isinstance(val, str): + pass + elif isinstance(val, tvm.tir.stmt.BufferStore): + tir.buffer_store(val.buffer, val.value, val.indices, val.predicate) + elif isinstance(val, (Buffer, Var)): + pass + else: + logger.warning(f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) + + def ctx_for(self, it): + self.check_continue_break() + it = unwrap_expr(it) + if isinstance(it, (SerialForWithStep, UnrollForWithStep)): + # Validate and compute the trip count before constructing the frame + if isinstance(it.step, (int, IntImm)): + step_value = it.step if isinstance(it.step, int) else it.step.value + if step_value == 0: + raise ValueError("Invalid stepped serial: step must be non-zero") + if step_value > 0: + real_stop = tir.ceildiv(it.stop - it.start, step_value) + else: + real_stop = tir.ceildiv(it.start - it.stop, -step_value) + else: + logger.warning(f"Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang") + real_stop = tir.ceildiv(it.stop - it.start, it.step) + if isinstance(it, UnrollForWithStep): + real_frame = tir.unroll(real_stop, annotations=it.annotations) + elif isinstance(it, SerialForWithStep): + real_frame = tir.serial(real_stop, annotations=it.annotations) + else: + raise TypeError( + f"Invalid for loop, got {it}({type(it)}), expect one of the following: " + "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding" + ) + with self.with_frame(real_frame) as v: + IRBuilder.name("_tmp", v) + yield it.start + v * it.step + else: + if not isinstance(it, tir.frame.ForFrame): + raise TypeError( + f"Invalid for loop, got {it}({type(it)}), expect one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding" + ) + with self.with_frame(it) as v: + yield v + + def ctx_continue(self): + self.check_continue_break() + # add a dummy frame for checking code after continue/break + self.enter_frame(ContinueFrame()) + tir.evaluate(tir.continue_loop()) + + def ctx_break(self): + self.check_continue_break() + # add a dummy frame for checking code after continue/break + self.enter_frame(BreakFrame()) + tir.evaluate(tir.break_loop()) + + def ctx_while(self, cond): + self.check_continue_break() + cond_v = cond() + cond_v_unwrap = unwrap_cond(cond_v) + if not isinstance(cond_v_unwrap, PrimExpr): + if cond_v_unwrap: + raise RuntimeError( + f"Infinite while loop detected in TileLang\n" + f"Condition: {cond_v}({type(cond_v)}) => 
{cond_v_unwrap}({type(cond_v_unwrap)})\n" + ) + else: + logger.warning( + "While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n", + f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n", + stack_info=True, + stacklevel=2, + ) + with self.with_frame(tir.While(cond_v_unwrap)): + yield None + + def bind(self, name, value, annot=BaseBuilder.empty): + self.check_continue_break() + locals = self.get_parent_locals() + orig_value = locals.get(name, None) + # if orig_value is a local.var, we use buffer_store to modify it immutably + # however, if rvalue is not a PrimExpr, such as buffer, + # we should not use buffer_store, and bind it instead + # ```py + # a = tl.alloc_var('float32') # bind var `a` + # a = tl.alloc_var('float32') # bind a new var `a_1` + # a = tl.alloc_shared((1,), T.float32) # bind a to new buffer + # b = a # get value of var `b = a_1[0]`` + # c = tl.alloc_var('float32') # bind var `c` + # c = a # get and assign `c[0] = a_1[0]` + # ``` + if isinstance(orig_value, Ref) and isinstance(value, (int, float, PrimExpr)): + orig_value.store(value) + return orig_value + if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)): + tir.buffer_store(orig_value, value, 0) + return orig_value + + # 2. Quick return for trivil types + if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)): + return value + if isinstance(value, tir.IntImm) and value.dtype == "int32": + return value.value + if isinstance(value, (Var, Buffer)): + # Bind TVM Var/Buffer names and also record scope so reusing the same + # Python name (e.g., loop vars like `i`) across different for-frames + # works without triggering out-of-scope errors. + IRBuilder.name(name, value) + if name != "_": + frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) + assert frame is not None, f"Variable `{name}` is not defined inside any control flow." + self.name_inside_frame[name] = self.frames[frame] + return value + + # 3. Bind immutable tilelang objects + res = self.bind_immutable(name, value) + + # 4. Check variable scope and shadowing + if name != "_": + frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) + assert frame is not None, f"Variable `{name}` is not defined inside any control flow." + if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: + logger.warning( + f"Variable `{name}` is declared twice, are you looking for a T.alloc_var?", + stack_info=True, + stacklevel=2, + ) + self.name_inside_frame[name] = self.frames[frame] + return res + + def unwrap_value(self, value): + """ + Unwrap some tilelang objects to get their inner value + """ + value = unwrap_expr(value) + # handle bx, by = tl.Kernel(128, 128), rval is frame + if isinstance(value, tir.frame.IRBuilderFrame): + return self.enter_frame(value) + else: + return value + + def bind_immutable(self, name, value): + """ + Bind an immutable tilelang objects. + The immutability means the result is usually not changed or re-assigned in a python block. 
+ """ + if name == "_": + # use _tmp to make the generated tir more readable + name = "_tmp" + if isinstance(value, tir.meta_var): + return value.value + elif isinstance(value, tir.frame.IRBuilderFrame): + if isinstance(value, tir.frame.ForFrame): + logger.warning( + "Binding a for frame to variable may cause undefined behavior in tilelang.", + stack_info=True, + stacklevel=2, + ) + return self.enter_frame(value) + elif isinstance(value, OutTensor): + arg = tir.arg( + name, + tir.buffer( + shape=value.shape, + dtype=value.dtype, + strides=value.strides, + ), + ) + arg._out_idx = self.out_tensor_cnt + self.out_tensor_cnt += 1 + return arg + elif isinstance(value, (Buffer, tir.IterVar, tir.Var)): + IRBuilder.name(name, value) + return value + else: + try: + value = tvm.runtime.convert(value) + except TypeError: + return value + frame = tir.LetStmt(value) + var = frame.var + IRBuilder.name(name, var) + return self.enter_frame(frame) + + def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): + self.check_continue_break() + if annot is not self.empty: + logger.warning("Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) + if isinstance(lval, Buffer): + tir.buffer_store(lval, value, sl) + else: + return super().assign_slice(lval, sl, value) + + def aug_assign(self, op, target, aug_value): + self.check_continue_break() + if isinstance(target, Ref): + target.store(eval_op(op, target.bufload, aug_value)) + return target + elif is_var(target): + tir.buffer_store(target, eval_op(op, target[0], aug_value), 0) + return target + elif isinstance(target, Buffer): + raise RuntimeError("Augmented assignment is not supported for Buffer") + else: + return super().aug_assign(op, target, aug_value) + + def aug_assign_slice(self, op, target, sl, aug_value): + self.check_continue_break() + if isinstance(target, Buffer): + tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl) + else: + return super().aug_assign_slice(op, target, sl, aug_value) + + def boolop(self, op, left, right=None): + left = unwrap_cond(left) + if isinstance(left, PrimExpr): + with self.with_frame(BoolOpFrame()): + if op == "And": + return tir.And(left, right()) + if op == "Or": + return tir.Or(left, right()) + if op == "Not": + return tir.Not(left) + raise RuntimeError(f"Unsupported boolean operator: {op}") + else: + return super().boolop(op, left, right) + + def ifexp(self, cond, then, otherwise): + cond = unwrap_cond(cond) + if isinstance(cond, PrimExpr): + with self.with_frame(BoolOpFrame()): + return tir.if_then_else(cond, then(), otherwise()) + else: + return super().ifexp(cond, then, otherwise) + + def ret(self, value=None): + self.check_continue_break() + # handle return T.alloc_var() + if value is None: + value = tuple() + elif isinstance(value, tuple): + value = tuple(self.unwrap_value(v) for v in value) + else: + value = self.unwrap_value(value) + last_macro = self.find_frame_idx(MacroFrame) + if last_macro is not None: + frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro) + if frame is not None: + raise NotImplementedError( + "Return from control flow is not supported yet. \n" + "You should allocate a var before the control flow, assign value inside the blocks, \n" + "and return the var after the control flow. 
i.e.\n" + "```\n" + "@T.macro\n" + "def my_macro(cond):\n" + " a = T.alloc_var(T.float16)\n" + " if cond:\n" + " a = 1.0\n" + " return a\n" + "```" + ) + return value + else: + if not isinstance(value, tuple): + value = (value,) + for v in value: + if not isinstance(v, Buffer) or not hasattr(v, "_out_idx"): + raise RuntimeError(f"Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})") + # convert 0, 1, 2 => -3, -2, -1 as the out tensor index + self.out_idx.append(v._out_idx - self.out_tensor_cnt) + if len(self.out_idx) != self.out_tensor_cnt: + raise RuntimeError(f"Not all tensor from `T.empty` are returned, only got {value}") + return NotImplemented + + def ctx_with(self, ctx): + self.check_continue_break() + if isinstance(ctx, tir.frame.IRBuilderFrame): + return self.with_frame(ctx) + else: + return super().ctx_with(ctx) + + def assert_expr(self, cond, msg=None): + self.check_continue_break() + cond = unwrap_cond(cond) + if msg is None: + msg = "Assertion failed" + if isinstance(cond, PrimExpr): + self.enter_frame(tir.Assert(cond, msg)) + elif not cond: + raise AssertionError(msg) + + def rval(self, name: str, value: Any) -> Any: + if name in self.name_inside_frame: + frame = self.name_inside_frame[name] + if frame not in self.frames: + raise RuntimeError( + f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n" + f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}." + ) + return self.unwrap_value(value) + + def macro_arg(self, name, value): + annot_value = self.macro_arg_annot.get(name, None) + if annot_value is Var or annot_value is Ref: + if annot_value is Var: + logger.warning("Use `T.Var` as macro annotations is deprecated, please use `T.Ref`") + if isinstance(value, BufferLoad): + if is_var(value.buffer): + return value.buffer + idx = [self.bind("_", idx) for idx in value.indices] + # indices = self.bind(f'_', value.indices) + return Ref(BufferLoad(value.buffer, indices=idx)) + if isinstance(value, BufferRegion): + region = [Range(self.bind("_", x.begin), end=self.bind("_", x.end) if x.end is not None else None) for x in value.region] + return BufferRegion(value.buffer, region=region) + raise ValueError( + f"To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})" + ) + elif isinstance(value, (PrimExpr, int, float)): + return self.bind(name, value) + else: + return value + + def prim_func_arg(self, name, value): + return self.func_annot.create_argument(name, value, self.arg_vt) + # if isinstance(value, (Buffer, Var)): + # return tir.arg(name, value) + # elif value is self.empty: + # raise ValueError(f'Argument `{name}` is not annotated') + # else: + # raise TypeError( + # f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") + + def arg(self, name, value): + if self.find_frame_idx(MacroFrame) is not None: + return self.macro_arg(name, value) + else: + return self.prim_func_arg(name, value) + + def override(self, name: str): + from tilelang.language import serial + + if name == "range": + return serial + raise ValueError(f"Unknown override: {name}") + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +@dataclass +class PrimFuncCreater(Generic[_P, _T]): + func_annot: FuncAnnot + ir_gen: IRGenerator[_P, _T] + orig_func: Callable[_P, _T] + + @property + def annot(self) -> dict[str, Annot]: + return self.func_annot.annots + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_P, _T]: 
+ builder = Builder(self.func_annot) + with builder.prim_func(self.orig_func.__name__): + self.ir_gen.gen(builder)(*args, **kwargs) + res: PrimFunc = builder.get() + res.ir_gen = self.ir_gen + res.orig_func = self.orig_func + res.func_annot = self.func_annot + res.out_idx_override = builder.out_idx or None + return res + + def __repr__(self): + fmt = pprint.pformat({"annot": self.func_annot.annots, "ir_gen": self.ir_gen, "orig_func": self.orig_func}, indent=2) + return f"{self.__class__.__name__}(\n{fmt}\n)" + + +if TYPE_CHECKING: + + class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): + params: list[tvm.tir.Var | tvm.tir.Buffer] + body: tvm.tir.Stmt + ret_type: tvm.ir.Type + buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer] + attrs: tvm.Attrs | None + span: Span | None + ir_gen: IRGenerator[_P, _T] | None + orig_func: Callable[_P, _T] | None + func_annot: FuncAnnot | None + out_idx_override: list[int] | None + +else: + PrimFunc = tvm.tir.PrimFunc + + +@dataclass +class Macro(Generic[_P, _T]): + name: str + orig_func: Callable[_P, _T] + ir_gen: IRGenerator[_P, _T] + annotations: dict[str, Any] + + @property + def source(self) -> str: + return self.ir_gen.source + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + builder = Builder.current() or Builder() + with builder.macro(self.name, self.annotations): + res = self.ir_gen.gen(builder)(*args, **kwargs) + return res + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return id(self) == id(other) + + +def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: + """ + Decorator that converts a Python function into a TileLang macro. + TileLang macro is very similar to PrimFunc, it can be used in prim_func or another macro. + Parameters + ---------- + func : Callable[_P, _T] + The Python function to be converted into a macro. This function will be analyzed + and transformed into an IR generation function. The function can take any parameters + (_P) and return any type (_T). + Returns + ------- + Macro[_P, _T] + A Macro object that wraps the original function with IR generation capabilities. + The returned Macro preserves the original function's signature (parameters _P and + return type _T) while adding metaprogramming capabilities. + Example: + -------- + >>> @macro + ... def my_macro(x: T.int32) -> T.int32: + ... return x ** 2 + >>> @prim_func + ... def my_func(A: T.Tensor((10,), T.int32), B: T.Tensor((10,), T.int32)): + ... with T.Kernel(1) as _: + ... for i in T.serial(10): + ... B[i] = my_macro(A[i]) + See Also + -------- + Macro : The class that wraps macro functions + mutate : The function that transforms Python code into IR generators + """ + + def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: + annotations = get_type_hints(func) + return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) + + return impl(func) if func is not None else impl + + +from typing import _eval_type + + +def get_type_hints(func): + annot = getattr(func, "__annotations__", None) + if annot is None: + raise TypeError(f"Failed to get function type hints, {func} is not a function") + hints = {} + # Build eval namespaces from function globals plus captured closure variables + # This lets annotations reference symbols like `n`, `h`, or dtype vars + # defined in the outer scope of a nested function. 
+ globalns = func.__globals__ + # Here we add nonlocals into localns, to capture the parameters declared in the parent function + # ```py + # def foo(): + # n = 128 # n is nonlocal + # def bar( + # A: T.Tensor(n, T.float32) # we add nonlocal in its eval context + # ): + # for i in range(n): ... + # ``` + # + # This is incomplete and buggy + # the only bug scenario the function body doesn't use the the parameters + # but such define-no-use scenario is very rare in writing kernels + # + # ```py + # def foo(): + # n = 128 + # def bar(A: T.Tensor((n,), T.float32)): + # ... # empty function, do not use `n` + localns = utils.get_func_nonlocals(func) + for name, value in annot.items(): + if name == "return": + continue + if isinstance(value, tvm.DataType): + hints[name] = value + continue + if value is None: + value = type(None) + if isinstance(value, str): + # if the annotation is string, is can be: (i) a T.float32 like annotations, (ii) a ForwardRef object + # typing doesn't handle (i), it will try to interpret T.float32 + # typing see: T.float32 is str('float32'), and there is no object named `flaot32` and give a NameError + # here we manually interpret it to return T.float32 object + try: + _, v = value.split(".", maxsplit=1) + except ValueError: + v = value + if v in dt._all_dtypes: + try: + hints[name] = eval(value, globalns, localns) + continue + except Exception: + pass + value = ForwardRef(value, is_argument=True, is_class=False) + hints[name] = _eval_type(value, globalns=globalns, localns=localns) + else: + hints[name] = value + return hints + + +def prim_func(func: Callable[_P, _T] = None, *, generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]: + """ + Decorator to create a primitive function (PrimFunc) for TileLang IR generation. + This decorator transforms a Python function into a TileLang primitive function by analyzing + its type annotations and generating intermediate representation (IR) code. It supports both + immediate construction (when all parameters are statically annotated) and generator mode + (for dynamic construction). + Parameters + ---------- + func : Callable[_P, _T], optional + The function to be decorated. Can be None when using decorator with arguments. + generator : bool, default=False + If True, returns a generator function that creates PrimFunc instances on demand. + If False, attempts to create a PrimFunc immediately using type annotations. + Returns + ------- + PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]] + - If `generator=False` and all parameters are statically annotated: returns a PrimFunc instance + - If `generator=True`: returns a callable that generates PrimFunc instances when invoked + - If used without parentheses: returns the decorator implementation function + Examples + -------- + Static annotation mode (immediate construction): + >>> @prim_func + ... def add_kernel(A: T.Buffer((128,), T.float32), + ... B: T.Buffer((128,), T.float32)): + ... for i in T.grid(128): + ... B[i] = A[i] + 1.0 + Generator mode (dynamic construction): + >>> @prim_func(generator=True) + ... def dynamic_kernel(A=T.Tensor((128,), T.float32)): + ... # function body + ... pass + >>> kernel_instance = dynamic_kernel() + With custom parameters: + >>> @prim_func(generator=True) + ... def parameterized_kernel(size: int = 128): + ... # function body using size parameter + ... 
pass + >>> kernel = parameterized_kernel(size=256) + See Also + -------- + Builder : The IR builder class used for constructing primitive functions + mutate : Function used to generate IR from the decorated function + """ + + def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: + sig = inspect.signature(func) + annot = get_type_hints(func) + + func_annot = FuncAnnot.from_sig_annots(sig, annot) + ir_gen = mutate(func) + + prim_func_generator = PrimFuncCreater(func_annot, ir_gen, orig_func=func) + + if func_annot.is_all_static(): + args = func_annot.get_all_static_args() + return prim_func_generator(**args) + else: + if generator is False: + unknown_args = func_annot.get_compile_time_unknown_args() + raise ValueError( + f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n" + f"Annotations:\n{func_annot.annots}" + f"Unknown Args: {unknown_args}" + ) + return prim_func_generator + + return impl(func) if func is not None else impl diff --git a/tilelang/original/tilelang/language/v2/dtypes.py b/tilelang/original/tilelang/language/v2/dtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..a42ba5a675dae04d8decaa8281a97da6d0117085 --- /dev/null +++ b/tilelang/original/tilelang/language/v2/dtypes.py @@ -0,0 +1,728 @@ +from tilelang import tvm +from tvm import ir +import torch +from typing import Generic, TypeVar, Union, TYPE_CHECKING +from tvm import tir +import tvm.script.ir_builder.tir._ffi_api as tb_ffi +import numpy as np + +_T = TypeVar("_T") + +if TYPE_CHECKING: + + class dtype(Generic[_T]): + def as_torch(self) -> torch.dtype: ... +else: + dtype = tvm.DataType + +# Python 3.9 compatibility: avoid PEP 604 unions at runtime +AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] + +_PYTHON_DTYPE_TO_STR = { + bool: "bool", + int: "int32", + float: "float32", +} + +_NUMPY_DTYPE_TO_STR = { + np.bool_: "bool", + np.short: "int16", + np.int_: "int64", + np.longlong: "int64", + np.half: "float16", + np.double: "float64", + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", + np.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", + np.float16: "float16", + np.float32: "float32", + np.float64: "float64", +} + +_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) + +_TORCH_DTYPE_TO_STR = { + torch.bool: "bool", + torch.short: "int16", + torch.int: "int32", + torch.long: "int64", + torch.half: "float16", + torch.float: "float32", + torch.double: "float64", + torch.int8: "int8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", + torch.uint8: "uint8", + torch.uint16: "uint16", + torch.uint32: "uint32", + torch.uint64: "uint64", + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.bfloat16: "bfloat16", +} + +_extended_torch_dtypes = [ + ("float8_e4m3fn",), + ("float8_e4m3fnuz",), + ("float8_e5m2",), + ("float8_e5m2fnuz",), + ("float8_e8m0fnu",), + ("float4_e2m1fnx2",), +] +for dtype_name_tuple in _extended_torch_dtypes: + dtype_name = dtype_name_tuple[0] + torch_dtype = getattr(torch, dtype_name, None) + if torch_dtype is not None: + _TORCH_DTYPE_TO_STR[torch_dtype] = dtype_name + + +_CANONICAL_TO_DISPLAY_STR = { + "double": "float64", + "float": "float32", + "int": "int32", + "long": "int64", + "short": "int16", + "uint": "uint32", + "ulong": "uint64", +} + +_STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} + +# _STR_TO_NUMPY_DTYPE = 
{v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} + +_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR} + +_STR_TO_TVM_DTYPE_CALL = { + "bool": "Boolean", + "int4": "Int4", + "int8": "Int8", + "int16": "Int16", + "int32": "Int32", + "int64": "Int64", + "uint8": "UInt8", + "uint16": "UInt16", + "uint32": "UInt32", + "uint64": "UInt64", + "float16": "Float16", + "float32": "Float32", + "float64": "Float64", + "bfloat16": "BFloat16", + "float8_e4m3": "Float8E4M3", + "float8_e4m3fn": "Float8E4M3FN", + "float8_e4m3fnuz": "Float8E4M3FNUZ", + "float8_e5m2": "Float8E5M2", + "float8_e5m2fnuz": "Float8E5M2FNUZ", + "float8_e8m0fnu": "Float8E8M0FNU", +} + +int_ = int + + +def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: + if isinstance(expr, int_): + return tvm.tir.const(expr, dtype=self) + if self in _STR_TO_TVM_DTYPE_CALL: + attr = _STR_TO_TVM_DTYPE_CALL[self] + call = getattr(tb_ffi, attr, None) + return call(expr, is_size_var) + # try to construct the ffi call + if self.startswith("uint"): + val = "UInt" + self[4:] + elif self.startswith("int"): + val = "Int" + self[3:] + elif self.startswith("float"): + val = "Float" + self[5:] + elif self.startswith("bfloat"): + val = "BFloat" + self[6:] + else: + raise TypeError(f"Invalid type {self}") + if "_" in val: + first, second = val.split("_", maxsplit=1) + val = first + second.upper() + call = getattr(tb_ffi, val, None) + if call is None: + raise TypeError( + f"Convert to datatype `{self}` is not supported by tvm\ncalling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`" + ) + return call(expr, is_size_var) + + +def __dtype_as_torch__(self: dtype) -> torch.dtype: + """Convert TileLang dtype to PyTorch dtype.""" + dtype_str = str(self) + + if dtype_str == "float8_e4m3": + # Check if we're on HIP (AMD ROCm) or CUDA + if torch.version.hip is not None: + # HIP backend - use float8_e4m3fnuz + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" + ) + return torch.float8_e4m3fnuz + else: + # CUDA backend - use float8_e4m3fn + assert hasattr(torch, "float8_e4m3fn"), ( + "torch.float8_e4m3fn is not supported in this version of torch. Please upgrade torch >= 2.1.0" + ) + return torch.float8_e4m3fn + elif dtype_str == "float8_e5m2": + assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0" + return torch.float8_e5m2 + elif dtype_str == "e4m3fnuz_float8": + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" + ) + return torch.float8_e4m3fnuz + elif dtype_str == "float8_e8m0fnu": + assert hasattr(torch, "float8_e8m0fnu"), ( + "torch.float8_e8m0fnu is not supported in this version of torch. Please upgrade torch >= 2.8.0" + ) + return torch.float8_e8m0fnu + elif dtype_str == "float4_e2m1fnx2": + assert hasattr(torch, "float4_e2m1fnx2"), ( + "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0" + ) + return torch.float4_e2m1fnx2 + elif dtype_str in _STR_TO_TORCH_DTYPE: + return _STR_TO_TORCH_DTYPE[dtype_str] + + raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. 
Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}") + + +__orig_dtype_new = dtype.__new__ + + +def __dtype_new__(cls, value: AnyDType) -> dtype: + if isinstance(value, str): + return __orig_dtype_new(cls, _CANONICAL_TO_DISPLAY_STR.get(value, value)) + elif value in _DTYPE_TO_STR: + return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) + else: + expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values())) + raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") + + +dtype.__call__ = __dtype_call__ +dtype.__new__ = __dtype_new__ +dtype.as_torch = __dtype_as_torch__ + + +def get_tvm_dtype(value: AnyDType) -> dtype: + if isinstance(value, (dtype, ir.Type)): + return value + return dtype(value) + + +if TYPE_CHECKING: + # yapf: disable + class bool(dtype): ... + class short(dtype): ... + class int(dtype): ... + class uint(dtype): ... + class long(dtype): ... + class half(dtype): ... + class float(dtype): ... + class double(dtype): ... + class int4(dtype): ... + class int8(dtype): ... + class int16(dtype): ... + class int32(dtype): ... + class int64(dtype): ... + class int8x2(dtype): ... + class int16x2(dtype): ... + class int32x2(dtype): ... + class int64x2(dtype): ... + class int8x4(dtype): ... + class int16x4(dtype): ... + class int32x4(dtype): ... + class int64x4(dtype): ... + class int8x8(dtype): ... + class int16x8(dtype): ... + class int32x8(dtype): ... + class int64x8(dtype): ... + class int8x16(dtype): ... + class int16x16(dtype): ... + class int32x16(dtype): ... + class int64x16(dtype): ... + class int8x32(dtype): ... + class int16x32(dtype): ... + class int32x32(dtype): ... + class int64x32(dtype): ... + class int8x64(dtype): ... + class int16x64(dtype): ... + class int32x64(dtype): ... + class int64x64(dtype): ... + class uint8(dtype): ... + class uint16(dtype): ... + class uint32(dtype): ... + class uint64(dtype): ... + class uint8x2(dtype): ... + class uint16x2(dtype): ... + class uint32x2(dtype): ... + class uint64x2(dtype): ... + class uint8x4(dtype): ... + class uint16x4(dtype): ... + class uint32x4(dtype): ... + class uint64x4(dtype): ... + class uint8x8(dtype): ... + class uint16x8(dtype): ... + class uint32x8(dtype): ... + class uint64x8(dtype): ... + class uint8x16(dtype): ... + class uint16x16(dtype): ... + class uint32x16(dtype): ... + class uint64x16(dtype): ... + class uint8x32(dtype): ... + class uint16x32(dtype): ... + class uint32x32(dtype): ... + class uint64x32(dtype): ... + class uint8x64(dtype): ... + class uint16x64(dtype): ... + class uint32x64(dtype): ... + class uint64x64(dtype): ... + class float16(dtype): ... + class float32(dtype): ... + class float64(dtype): ... + class float16x2(dtype): ... + class float32x2(dtype): ... + class float64x2(dtype): ... + class float16x4(dtype): ... + class float32x4(dtype): ... + class float64x4(dtype): ... + class float16x8(dtype): ... + class float32x8(dtype): ... + class float64x8(dtype): ... + class float16x16(dtype): ... + class float32x16(dtype): ... + class float64x16(dtype): ... + class float16x32(dtype): ... + class float32x32(dtype): ... + class float64x32(dtype): ... + class float16x64(dtype): ... + class float32x64(dtype): ... + class float64x64(dtype): ... + class float8_e3m4(dtype): ... + class float8_e3m4x2(dtype): ... + class float8_e3m4x4(dtype): ... + class float8_e3m4x8(dtype): ... + class float8_e3m4x16(dtype): ... + class float8_e3m4x32(dtype): ... + class float8_e3m4x64(dtype): ... + class float8_e4m3(dtype): ... + class float8_e4m3x2(dtype): ... 
+ class float8_e4m3x4(dtype): ... + class float8_e4m3x8(dtype): ... + class float8_e4m3x16(dtype): ... + class float8_e4m3x32(dtype): ... + class float8_e4m3x64(dtype): ... + class float8_e4m3b11fnuz(dtype): ... + class float8_e4m3b11fnuzx2(dtype): ... + class float8_e4m3b11fnuzx4(dtype): ... + class float8_e4m3b11fnuzx8(dtype): ... + class float8_e4m3b11fnuzx16(dtype): ... + class float8_e4m3b11fnuzx32(dtype): ... + class float8_e4m3b11fnuzx64(dtype): ... + class float8_e4m3fn(dtype): ... + class float8_e4m3fnx2(dtype): ... + class float8_e4m3fnx4(dtype): ... + class float8_e4m3fnx8(dtype): ... + class float8_e4m3fnx16(dtype): ... + class float8_e4m3fnx32(dtype): ... + class float8_e4m3fnx64(dtype): ... + class float8_e4m3fnuz(dtype): ... + class float8_e4m3fnuzx2(dtype): ... + class float8_e4m3fnuzx4(dtype): ... + class float8_e4m3fnuzx8(dtype): ... + class float8_e4m3fnuzx16(dtype): ... + class float8_e4m3fnuzx32(dtype): ... + class float8_e4m3fnuzx64(dtype): ... + class float8_e5m2(dtype): ... + class float8_e5m2x2(dtype): ... + class float8_e5m2x4(dtype): ... + class float8_e5m2x8(dtype): ... + class float8_e5m2x16(dtype): ... + class float8_e5m2x32(dtype): ... + class float8_e5m2x64(dtype): ... + class float8_e5m2fnuz(dtype): ... + class float8_e5m2fnuzx2(dtype): ... + class float8_e5m2fnuzx4(dtype): ... + class float8_e5m2fnuzx8(dtype): ... + class float8_e5m2fnuzx16(dtype): ... + class float8_e5m2fnuzx32(dtype): ... + class float8_e5m2fnuzx64(dtype): ... + class float8_e8m0fnu(dtype): ... + class float8_e8m0fnux2(dtype): ... + class float8_e8m0fnux4(dtype): ... + class float8_e8m0fnux8(dtype): ... + class float8_e8m0fnux16(dtype): ... + class float8_e8m0fnux32(dtype): ... + class float8_e8m0fnux64(dtype): ... + class float6_e2m3fn(dtype): ... + class float6_e2m3fnx2(dtype): ... + class float6_e2m3fnx4(dtype): ... + class float6_e2m3fnx8(dtype): ... + class float6_e2m3fnx16(dtype): ... + class float6_e2m3fnx32(dtype): ... + class float6_e2m3fnx64(dtype): ... + class float6_e3m2fn(dtype): ... + class float6_e3m2fnx2(dtype): ... + class float6_e3m2fnx4(dtype): ... + class float6_e3m2fnx8(dtype): ... + class float6_e3m2fnx16(dtype): ... + class float6_e3m2fnx32(dtype): ... + class float6_e3m2fnx64(dtype): ... + class float4_e2m1fn(dtype): ... + class float4_e2m1fnx2(dtype): ... + class float4_e2m1fnx4(dtype): ... + class float4_e2m1fnx8(dtype): ... + class float4_e2m1fnx16(dtype): ... + class float4_e2m1fnx32(dtype): ... + class float4_e2m1fnx64(dtype): ... + class bfloat16(dtype): ... 
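+    # Illustrative note (a sketch, not part of the stub list above): at runtime these
+    # names are `dtype` instances rather than classes, so for example `T.float16.as_torch()`
+    # yields `torch.float16` and `T.int32(8)` builds a `tir.const(8, "int32")` through the
+    # patched `dtype.__call__` defined earlier in this module.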
+ # yapf: enable + +else: + bool = dtype("bool") + short = dtype("int16") + int = dtype("int32") + uint = dtype("uint32") + long = dtype("int64") + half = dtype("float16") + float = dtype("float32") + double = dtype("float64") + int4 = dtype("int4") + int8 = dtype("int8") + int16 = dtype("int16") + int32 = dtype("int32") + int64 = dtype("int64") + int8x2 = dtype("int8x2") + int16x2 = dtype("int16x2") + int32x2 = dtype("int32x2") + int64x2 = dtype("int64x2") + int8x4 = dtype("int8x4") + int16x4 = dtype("int16x4") + int32x4 = dtype("int32x4") + int64x4 = dtype("int64x4") + int8x8 = dtype("int8x8") + int16x8 = dtype("int16x8") + int32x8 = dtype("int32x8") + int64x8 = dtype("int64x8") + int8x16 = dtype("int8x16") + int16x16 = dtype("int16x16") + int32x16 = dtype("int32x16") + int64x16 = dtype("int64x16") + int8x32 = dtype("int8x32") + int16x32 = dtype("int16x32") + int32x32 = dtype("int32x32") + int64x32 = dtype("int64x32") + int8x64 = dtype("int8x64") + int16x64 = dtype("int16x64") + int32x64 = dtype("int32x64") + int64x64 = dtype("int64x64") + uint8 = dtype("uint8") + uint16 = dtype("uint16") + uint32 = dtype("uint32") + uint64 = dtype("uint64") + uint8x2 = dtype("uint8x2") + uint16x2 = dtype("uint16x2") + uint32x2 = dtype("uint32x2") + uint64x2 = dtype("uint64x2") + uint8x4 = dtype("uint8x4") + uint16x4 = dtype("uint16x4") + uint32x4 = dtype("uint32x4") + uint64x4 = dtype("uint64x4") + uint8x8 = dtype("uint8x8") + uint16x8 = dtype("uint16x8") + uint32x8 = dtype("uint32x8") + uint64x8 = dtype("uint64x8") + uint8x16 = dtype("uint8x16") + uint16x16 = dtype("uint16x16") + uint32x16 = dtype("uint32x16") + uint64x16 = dtype("uint64x16") + uint8x32 = dtype("uint8x32") + uint16x32 = dtype("uint16x32") + uint32x32 = dtype("uint32x32") + uint64x32 = dtype("uint64x32") + uint8x64 = dtype("uint8x64") + uint16x64 = dtype("uint16x64") + uint32x64 = dtype("uint32x64") + uint64x64 = dtype("uint64x64") + float16 = dtype("float16") + float32 = dtype("float32") + float64 = dtype("float64") + float16x2 = dtype("float16x2") + float32x2 = dtype("float32x2") + float64x2 = dtype("float64x2") + float16x4 = dtype("float16x4") + float32x4 = dtype("float32x4") + float64x4 = dtype("float64x4") + float16x8 = dtype("float16x8") + float32x8 = dtype("float32x8") + float64x8 = dtype("float64x8") + float16x16 = dtype("float16x16") + float32x16 = dtype("float32x16") + float64x16 = dtype("float64x16") + float16x32 = dtype("float16x32") + float32x32 = dtype("float32x32") + float64x32 = dtype("float64x32") + float16x64 = dtype("float16x64") + float32x64 = dtype("float32x64") + float64x64 = dtype("float64x64") + float8_e3m4 = dtype("float8_e3m4") + float8_e3m4x2 = dtype("float8_e3m4x2") + float8_e3m4x4 = dtype("float8_e3m4x4") + float8_e3m4x8 = dtype("float8_e3m4x8") + float8_e3m4x16 = dtype("float8_e3m4x16") + float8_e3m4x32 = dtype("float8_e3m4x32") + float8_e3m4x64 = dtype("float8_e3m4x64") + float8_e4m3 = dtype("float8_e4m3") + float8_e4m3x2 = dtype("float8_e4m3x2") + float8_e4m3x4 = dtype("float8_e4m3x4") + float8_e4m3x8 = dtype("float8_e4m3x8") + float8_e4m3x16 = dtype("float8_e4m3x16") + float8_e4m3x32 = dtype("float8_e4m3x32") + float8_e4m3x64 = dtype("float8_e4m3x64") + float8_e4m3b11fnuz = dtype("float8_e4m3b11fnuz") + float8_e4m3b11fnuzx2 = dtype("float8_e4m3b11fnuzx2") + float8_e4m3b11fnuzx4 = dtype("float8_e4m3b11fnuzx4") + float8_e4m3b11fnuzx8 = dtype("float8_e4m3b11fnuzx8") + float8_e4m3b11fnuzx16 = dtype("float8_e4m3b11fnuzx16") + float8_e4m3b11fnuzx32 = dtype("float8_e4m3b11fnuzx32") + float8_e4m3b11fnuzx64 = 
dtype("float8_e4m3b11fnuzx64") + float8_e4m3fn = dtype("float8_e4m3fn") + float8_e4m3fnx2 = dtype("float8_e4m3fnx2") + float8_e4m3fnx4 = dtype("float8_e4m3fnx4") + float8_e4m3fnx8 = dtype("float8_e4m3fnx8") + float8_e4m3fnx16 = dtype("float8_e4m3fnx16") + float8_e4m3fnx32 = dtype("float8_e4m3fnx32") + float8_e4m3fnx64 = dtype("float8_e4m3fnx64") + float8_e4m3fnuz = dtype("float8_e4m3fnuz") + float8_e4m3fnuzx2 = dtype("float8_e4m3fnuzx2") + float8_e4m3fnuzx4 = dtype("float8_e4m3fnuzx4") + float8_e4m3fnuzx8 = dtype("float8_e4m3fnuzx8") + float8_e4m3fnuzx16 = dtype("float8_e4m3fnuzx16") + float8_e4m3fnuzx32 = dtype("float8_e4m3fnuzx32") + float8_e4m3fnuzx64 = dtype("float8_e4m3fnuzx64") + float8_e5m2 = dtype("float8_e5m2") + float8_e5m2x2 = dtype("float8_e5m2x2") + float8_e5m2x4 = dtype("float8_e5m2x4") + float8_e5m2x8 = dtype("float8_e5m2x8") + float8_e5m2x16 = dtype("float8_e5m2x16") + float8_e5m2x32 = dtype("float8_e5m2x32") + float8_e5m2x64 = dtype("float8_e5m2x64") + float8_e5m2fnuz = dtype("float8_e5m2fnuz") + float8_e5m2fnuzx2 = dtype("float8_e5m2fnuzx2") + float8_e5m2fnuzx4 = dtype("float8_e5m2fnuzx4") + float8_e5m2fnuzx8 = dtype("float8_e5m2fnuzx8") + float8_e5m2fnuzx16 = dtype("float8_e5m2fnuzx16") + float8_e5m2fnuzx32 = dtype("float8_e5m2fnuzx32") + float8_e5m2fnuzx64 = dtype("float8_e5m2fnuzx64") + float8_e8m0fnu = dtype("float8_e8m0fnu") + float8_e8m0fnux2 = dtype("float8_e8m0fnux2") + float8_e8m0fnux4 = dtype("float8_e8m0fnux4") + float8_e8m0fnux8 = dtype("float8_e8m0fnux8") + float8_e8m0fnux16 = dtype("float8_e8m0fnux16") + float8_e8m0fnux32 = dtype("float8_e8m0fnux32") + float8_e8m0fnux64 = dtype("float8_e8m0fnux64") + float6_e2m3fn = dtype("float6_e2m3fn") + float6_e2m3fnx2 = dtype("float6_e2m3fnx2") + float6_e2m3fnx4 = dtype("float6_e2m3fnx4") + float6_e2m3fnx8 = dtype("float6_e2m3fnx8") + float6_e2m3fnx16 = dtype("float6_e2m3fnx16") + float6_e2m3fnx32 = dtype("float6_e2m3fnx32") + float6_e2m3fnx64 = dtype("float6_e2m3fnx64") + float6_e3m2fn = dtype("float6_e3m2fn") + float6_e3m2fnx2 = dtype("float6_e3m2fnx2") + float6_e3m2fnx4 = dtype("float6_e3m2fnx4") + float6_e3m2fnx8 = dtype("float6_e3m2fnx8") + float6_e3m2fnx16 = dtype("float6_e3m2fnx16") + float6_e3m2fnx32 = dtype("float6_e3m2fnx32") + float6_e3m2fnx64 = dtype("float6_e3m2fnx64") + float4_e2m1fn = dtype("float4_e2m1fn") + float4_e2m1fnx2 = dtype("float4_e2m1fnx2") + float4_e2m1fnx4 = dtype("float4_e2m1fnx4") + float4_e2m1fnx8 = dtype("float4_e2m1fnx8") + float4_e2m1fnx16 = dtype("float4_e2m1fnx16") + float4_e2m1fnx32 = dtype("float4_e2m1fnx32") + float4_e2m1fnx64 = dtype("float4_e2m1fnx64") + bfloat16 = dtype("bfloat16") + +_all_dtypes = { + "bool", + "short", + "int", + "uint", + "long", + "half", + "float", + "double", + "int4", + "int8", + "int16", + "int32", + "int64", + "int8x2", + "int16x2", + "int32x2", + "int64x2", + "int8x4", + "int16x4", + "int32x4", + "int64x4", + "int8x8", + "int16x8", + "int32x8", + "int64x8", + "int8x16", + "int16x16", + "int32x16", + "int64x16", + "int8x32", + "int16x32", + "int32x32", + "int64x32", + "int8x64", + "int16x64", + "int32x64", + "int64x64", + "uint8", + "uint16", + "uint32", + "uint64", + "uint8x2", + "uint16x2", + "uint32x2", + "uint64x2", + "uint8x4", + "uint16x4", + "uint32x4", + "uint64x4", + "uint8x8", + "uint16x8", + "uint32x8", + "uint64x8", + "uint8x16", + "uint16x16", + "uint32x16", + "uint64x16", + "uint8x32", + "uint16x32", + "uint32x32", + "uint64x32", + "uint8x64", + "uint16x64", + "uint32x64", + "uint64x64", + "float16", + "float32", + "float64", + 
"float16x2", + "float32x2", + "float64x2", + "float16x4", + "float32x4", + "float64x4", + "float16x8", + "float32x8", + "float64x8", + "float16x16", + "float32x16", + "float64x16", + "float16x32", + "float32x32", + "float64x32", + "float16x64", + "float32x64", + "float64x64", + "float8_e3m4", + "float8_e3m4x2", + "float8_e3m4x4", + "float8_e3m4x8", + "float8_e3m4x16", + "float8_e3m4x32", + "float8_e3m4x64", + "float8_e4m3", + "float8_e4m3x2", + "float8_e4m3x4", + "float8_e4m3x8", + "float8_e4m3x16", + "float8_e4m3x32", + "float8_e4m3x64", + "float8_e4m3b11fnuz", + "float8_e4m3b11fnuzx2", + "float8_e4m3b11fnuzx4", + "float8_e4m3b11fnuzx8", + "float8_e4m3b11fnuzx16", + "float8_e4m3b11fnuzx32", + "float8_e4m3b11fnuzx64", + "float8_e4m3fn", + "float8_e4m3fnx2", + "float8_e4m3fnx4", + "float8_e4m3fnx8", + "float8_e4m3fnx16", + "float8_e4m3fnx32", + "float8_e4m3fnx64", + "float8_e4m3fnuz", + "float8_e4m3fnuzx2", + "float8_e4m3fnuzx4", + "float8_e4m3fnuzx8", + "float8_e4m3fnuzx16", + "float8_e4m3fnuzx32", + "float8_e4m3fnuzx64", + "float8_e5m2", + "float8_e5m2x2", + "float8_e5m2x4", + "float8_e5m2x8", + "float8_e5m2x16", + "float8_e5m2x32", + "float8_e5m2x64", + "float8_e5m2fnuz", + "float8_e5m2fnuzx2", + "float8_e5m2fnuzx4", + "float8_e5m2fnuzx8", + "float8_e5m2fnuzx16", + "float8_e5m2fnuzx32", + "float8_e5m2fnuzx64", + "float8_e8m0fnu", + "float8_e8m0fnux2", + "float8_e8m0fnux4", + "float8_e8m0fnux8", + "float8_e8m0fnux16", + "float8_e8m0fnux32", + "float8_e8m0fnux64", + "float6_e2m3fn", + "float6_e2m3fnx2", + "float6_e2m3fnx4", + "float6_e2m3fnx8", + "float6_e2m3fnx16", + "float6_e2m3fnx32", + "float6_e2m3fnx64", + "float6_e3m2fn", + "float6_e3m2fnx2", + "float6_e3m2fnx4", + "float6_e3m2fnx8", + "float6_e3m2fnx16", + "float6_e3m2fnx32", + "float6_e3m2fnx64", + "float4_e2m1fn", + "float4_e2m1fnx2", + "float4_e2m1fnx4", + "float4_e2m1fnx8", + "float4_e2m1fnx16", + "float4_e2m1fnx32", + "float4_e2m1fnx64", + "bfloat16", +} + +__all__ = list(_all_dtypes) + [ + "dtype", + "AnyDType", + "get_tvm_dtype", +] diff --git a/tilelang/original/tilelang/language/v2/utils.py b/tilelang/original/tilelang/language/v2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..207bd92ad3290fdad79a3c77ece32c9698163751 --- /dev/null +++ b/tilelang/original/tilelang/language/v2/utils.py @@ -0,0 +1,98 @@ +from __future__ import annotations +import ast +import inspect +from typing import Any, Callable, Literal +from tilelang import env +from hashlib import sha256 +from tvm import tir +import linecache + + +def disk_compile(source, name): + cache_dir = env.TILELANG_CACHE_DIR + if cache_dir is not None: + import os + + save_dir = os.path.join(cache_dir, "py-cache") + os.makedirs(save_dir, exist_ok=True) + hash_sfx = sha256(source.encode("utf-8")).hexdigest()[:8] + path = os.path.join(save_dir, f"{name}.{hash_sfx}.py") + with open(path, "w") as f: + f.write(source) + linecache.cache[path] = (len(source), None, source.splitlines(), path) + return compile(source, path, "exec") + + +def _remove_leading_ident(source: str): + lines = source.splitlines() + if not lines: + return source + ident_size = len(lines[0]) - len(lines[0].lstrip()) + return "\n".join([line[ident_size:] if len(line) >= ident_size else line for line in lines]) + + +def get_func_nonlocals(func): + """A modified version of `inspect.getclosurevars`""" + + if inspect.ismethod(func): + func = func.__func__ + + if not inspect.isfunction(func): + raise TypeError(f"{func!r} is not a Python function") + + code = func.__code__ + # Nonlocal 
references are named in co_freevars and resolved + # by looking them up in __closure__ by positional index + nonlocal_vars = {} + if func.__closure__ is not None: + for var, cell in zip(code.co_freevars, func.__closure__): + try: + nonlocal_vars[var] = cell.cell_contents + except ValueError as err: + # cell_contents may raise ValueError if the cell is empty. + if "empty" not in str(err): + raise + return nonlocal_vars + + +def get_ast(func: Callable): + _, start = inspect.getsourcelines(func) + filename = inspect.getsourcefile(func) or inspect.getfile(func) + source = inspect.getsource(func) + source = _remove_leading_ident(source) + source = "\n" * (start - 1) + source + tree = ast.parse(source, filename=filename) + return tree + + +CompileMethod = Literal["direct", "disk"] + + +def get_compiled_object(source: str | ast.AST, name: str, filename: str = None, globals: dict[str, Any] = None): + if isinstance(source, ast.AST): + assert filename is not None, "filename must be provided when source is an AST" + try: + if isinstance(source, ast.AST): + ast.fix_missing_locations(source) + compiled = compile(source, filename, "exec") + else: + compiled = disk_compile(source, name) + except Exception as e: + source_str = source if isinstance(source, str) else ast.unparse(source) + raise RuntimeError(f"Failed to compile source for {name}, Error: {e}:\n{source_str}") from e + locs = {} + exec(compiled, globals, locs) + return locs[name] + + +def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> tuple[Any, ...]: + """Construct row-major strides from shape.""" + strides = [] + stride = 1 + for s in shape[::-1]: + strides.append(stride) + stride *= s + if not allow_prim_expr and isinstance(stride, tir.PrimExpr): + raise ValueError("Cannot construct strides with PrimExpr when allow_prim_expr is False.") + strides = tuple(reversed(strides)) + return strides diff --git a/tilelang/original/tilelang/language/warpgroup.py b/tilelang/original/tilelang/language/warpgroup.py new file mode 100644 index 0000000000000000000000000000000000000000..77cf6924583262aa4d34b708f1ade38a64347477 --- /dev/null +++ b/tilelang/original/tilelang/language/warpgroup.py @@ -0,0 +1,57 @@ +"""The language interface for tl programs.""" + +from tvm.script.ir_builder.tir.frame import TIRFrame +from tvm.ffi import register_object +from tilelang import _ffi_api +from .kernel import get_thread_bindings, get_thread_extents + + +@register_object("tl.WarpSpecializeFrame") +class WarpSpecializeFrame(TIRFrame): + """ + WarpSpecializeFrame is a custom TIRFrame that manages warp group indices + and handles the entry and exit of the kernel launch scope. + """ + + +def WarpSpecialize(*warp_group_idx): + """Tools to construct a warp group frame. + + Parameters + ---------- + warp_group_idx : int + A integer representing warp group index + Or a list of integers representing blockDim.(x|y|z) + if the value is -1, we skip the threadIdx.x binding. + + Returns + ------- + res : Tuple[frame.LaunchThreadFrame] + The result LaunchThreadFrame. + Examples: + >>> T.ws(0) -> if tx < 128 + >>> T.ws(1) -> if tx >= 128 and tx < 256 + >>> T.ws(0, 1) -> if tx < 128 or (tx >= 128 and tx < 256) + """ + id_x, id_y, id_z = get_thread_bindings() + ex_x, ex_y, ex_z = get_thread_extents() + tid = id_x + if ex_y > 1: + tid = id_y * ex_x + tid + if ex_z > 1: + tid = id_z * (ex_y * ex_x) + tid + + # only available for nvidia gpus. 
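+    # Illustrative note: with the canonical 128-thread warp group below and the
+    # flattened thread index `tid`, T.ws(0) selects threads 0..127 and T.ws(1)
+    # selects threads 128..255, matching the examples in the docstring above.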
+ warp_group_size = 128 + + warp_group_ids: list[int] = [] + for warp_group_id in warp_group_idx: + warp_group_ids.append(warp_group_id) + + assert len(warp_group_ids) > 0, "warp_group_idx must be non-empty" + + return _ffi_api.WarpSpecialize(warp_group_ids, tid, warp_group_size) + + +# Alias for WarpSpecialize for more concise usage +ws = WarpSpecialize diff --git a/tilelang/original/tilelang/layout/__init__.py b/tilelang/original/tilelang/layout/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..777802d2c78be3defa35d697440552161cd3d422 --- /dev/null +++ b/tilelang/original/tilelang/layout/__init__.py @@ -0,0 +1,16 @@ +"""Wrapping Layouts.""" +# pylint: disable=invalid-name, unsupported-binary-operation + +from .layout import Layout # noqa: F401 +from .fragment import Fragment # noqa: F401 +from .swizzle import ( + make_swizzled_layout, # noqa: F401 + make_volta_swizzled_layout, # noqa: F401 + make_wgmma_swizzled_layout, # noqa: F401 + make_tcgen05mma_swizzled_layout, # noqa: F401 + make_full_bank_swizzled_layout, # noqa: F401 + make_half_bank_swizzled_layout, # noqa: F401 + make_quarter_bank_swizzled_layout, # noqa: F401 + make_linear_layout, # noqa: F401 +) +from .gemm_sp import make_cutlass_metadata_layout # noqa: F401 diff --git a/tilelang/original/tilelang/layout/fragment.py b/tilelang/original/tilelang/layout/fragment.py new file mode 100644 index 0000000000000000000000000000000000000000..256a7d5ee169c1545c38ad192131ed115810a2a8 --- /dev/null +++ b/tilelang/original/tilelang/layout/fragment.py @@ -0,0 +1,205 @@ +"""Wrapping Layouts.""" + +# pylint: disable=invalid-name, unsupported-binary-operation +import tvm +import tvm_ffi +from tvm.ir import Range +from tvm.tir import IterVar, Var, PrimExpr, IndexMap +from tilelang import _ffi_api +from tilelang.layout import Layout + + +@tvm_ffi.register_object("tl.Fragment") +class Fragment(Layout): + """ + A Fragment layout object that encapsulates iteration variables (forward_vars), + thread iteration variables (forward_thread), and index transformations + (forward_index). This class supports replication (thread_replicate) and + index mapping for fine-grained control over multi-dimensional data layouts. + """ + + # Disable the linter warning about not calling super().__init__() + # because this object is created via TVM's FFI constructor mechanism. + # pylint: disable=super-init-not-called + def __init__(self, shape, forward_fn=None, forward_thread_fn=None, replicate=1, forward_index_fn=None): + """ + Initialize the Fragment with iteration variables and optional thread replication. + + Parameters + ---------- + shape : list[int] + A list of integer sizes for each dimension of this fragment. + forward_fn : callable, optional + A function that takes the iteration variables, plus optionally a replicate + IterVar, and returns a tuple: (forward_thread, forward_index). + It is used when you want to compute both thread mapping and index mapping + from the shape variables. + forward_thread_fn : callable, optional + A function that takes iteration variables (plus optionally a replicate Var) + and returns an IterVar representing the thread index. This is used if + `forward_fn` is not provided, and only the thread mapping is derived + here while the index mapping is derived separately via `forward_index_fn`. + replicate : int, optional + How many times to replicate the iteration over the threads, typically + used for multi-threading or replication in the hardware threads. Defaults to 1. 
+ forward_index_fn : callable, optional + A function that takes iteration variables and returns an index or list + of indices for this fragment. Used when `forward_fn` is None and + the index transformation is derived separately. + """ + + # Create a list of IterVar objects based on shape dimensions + # Each dimension is assigned a range from 0..size and a Var like i0, i1, etc. + forward_vars = [] + for idx, size in enumerate(shape): + iv = IterVar(Range(0, size), Var(f"i{idx}", "int32"), 0) + forward_vars.append(iv) + + # Collect the underlying variables (i.e., Var objects) from the IterVars + vars = [iv.var for iv in forward_vars] + + # Initialize placeholders for optional outputs + forward_thread: IterVar = None + forward_index: tvm.ir.container.Array = None + thread_replicate: IterVar = None + + # If a forward_fn is provided, use it to derive both thread mapping and indices + if forward_fn is not None: + # If replication is greater than 1, create a replicate IterVar + # and pass it to forward_fn + if replicate > 1: + thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0) + forward_thread, forward_index = forward_fn(*vars, thread_replicate) + else: + thread_replicate = None + forward_thread, forward_index = forward_fn(*vars) + else: + # If no forward_fn is provided, compute forward_index (if any) via forward_index_fn + forward_index = forward_index_fn(*vars) if forward_index_fn else None + # Then compute forward_thread via forward_thread_fn + if replicate > 1: + thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0) + forward_thread = forward_thread_fn(*vars, thread_replicate.var) + else: + thread_replicate = None + forward_thread = forward_thread_fn(*vars) + + # Ensure forward_index is an array if it isn't None + if forward_index is None: + forward_index = [] + elif not isinstance(forward_index, tvm.ir.container.Array): + forward_index = [forward_index] + + # Call TVM FFI constructor to set up internal data structures + self.__init_handle_by_constructor__( + _ffi_api.Fragment, + forward_vars, + forward_index, + forward_thread, + thread_replicate, + ) + + @property + def thread(self): + """ + Returns the forward_thread (IterVar) of the Fragment, representing + the thread dimension or mapping. + """ + return _ffi_api.Fragment_thread(self) + + def get_thread_size(self): + """ + Returns the extent (range size) of the thread dimension. + If the Fragment was replicated over threads, this will reflect + the number of threads. + """ + return _ffi_api.Fragment_thread_size(self) + + def repeat(self, repeats, repeat_on_thread: bool = False, lower_dim_first: bool = True) -> "Fragment": + """ + Returns a new Fragment that repeats the iteration space a given number of times. + + Parameters + ---------- + repeats : int + Number of times to repeat. + repeat_on_thread : bool, optional + If set, the repeat will happen on the thread dimension. + lower_dim_first : bool, optional + If set to True, repeat on lower dimensions first. + + Returns + ------- + Fragment + A new Fragment with the repeated iteration space. + """ + return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) + + def replicate(self, replicate: int) -> "Fragment": + """ + Replicate the Fragment across a new thread dimension. + + Parameters + ---------- + replicate : int + The replication factor or number of threads. + + Returns + ------- + Fragment + A new Fragment with an additional replicate dimension. 
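+
+        Example (illustrative): ``frag.replicate(2)`` returns a new Fragment whose
+        iteration space is additionally distributed across 2 replicated threads,
+        which is then reflected by ``get_thread_size()``.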
+ """ + return _ffi_api.Fragment_replicate(self, replicate) + + def condense_rep_var(self) -> "Fragment": + """ + Condense or fold the replicate variable into the existing iteration space. + This operation may be used to reduce dimensionality if the replicate variable + is no longer needed as a separate dimension. + + Returns + ------- + Fragment + A new Fragment where the replicate variable is condensed. + """ + return _ffi_api.Fragment_condense_rep_var(self) + + def map_forward_thread(self, indices: list[PrimExpr]) -> PrimExpr: + """ + Get the thread mapping expression for a given set of argument indices. + + Parameters + ---------- + indices : list of PrimExpr + Indices for which to compute the thread mapping. + + Returns + ------- + PrimExpr + The computed thread expression for the provided indices. + """ + # Retrieve the forward iteration variables + forward_vars = self.get_forward_vars() + # The thread dimension (IterVar) is accessed via the `thread` property + forward_thread = self.thread + # Construct an IndexMap to map the provided args into the final thread index + index_map = IndexMap(initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None) + return index_map.map_indices(indices) + + def __repr__(self): + """ + String representation of the Fragment for debugging and logging. + + Returns + ------- + str + A string showing the thread dimension and the index dimension. + """ + return self._DebugOutput() + # return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" + + def is_equal(self, other: "Fragment") -> bool: + """ + Check if the current fragment is equal to another fragment. + """ + return _ffi_api.Fragment_is_equal(self, other) diff --git a/tilelang/original/tilelang/layout/gemm_sp.py b/tilelang/original/tilelang/layout/gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae836bc88eebe4aa319bc9283e0390c4e96246a --- /dev/null +++ b/tilelang/original/tilelang/layout/gemm_sp.py @@ -0,0 +1,161 @@ +"""Wrapping Layouts.""" + +# pylint: disable=invalid-name, unsupported-binary-operation +from __future__ import annotations +import tvm +import tilelang.language as T +import warnings + +from tilelang.contrib import nvcc +from math import prod + + +def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]: + res = [] + for x in basis: + res.append(index_1d % x) + index_1d //= x + return res + + +def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int): + """Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem. + + Args: + buffer: metadata buffer shape, for sm90 it should be a 8-bit type + mma_dtype: dtype of mma operand A, different dtypes result in different layout atom + block_k: tiling size along K dim, different block_ks results in different layout atom. 
+ """ + + if block_k > 128: + block_k = 128 + # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 + warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2) + if mma_dtype not in [ + T.float16, + T.bfloat16, + T.float32, + T.int8, + T.float8_e4m3, + T.float8_e4m3fn, + T.float8_e4m3fnuz, + T.float8_e5m2, + T.float8_e5m2fnuz, + ]: + raise NotImplementedError(f"Unsupported dtype: {mma_dtype}") + + if buffer.dtype not in [T.uint8, T.int8]: + raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}") + + bits_map = { + "float16": 16, + "bfloat16": 16, + "float32": 32, + "int8": 8, + "float8_e4m3": 8, + "float8_e4m3fn": 8, + "float8_e4m3fnuz": 8, + "float8_e5m2": 8, + "float8_e5m2fnuz": 8, + } + + # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117 + # get atom layout according to mma dtype + BlockK = 512 // bits_map[mma_dtype] + if block_k % BlockK != 0: + raise ValueError(f"Tile K is too small, which should be at least {BlockK} for {mma_dtype}") + NumK = block_k // BlockK # block_k is MinTileShapeK + + def gen_stride(shape_ik, order): + stride_ik = [None for _ in range(len(shape_ik))] + order = [(i, o) for i, o in enumerate(order)] + order.sort(key=lambda x: x[1]) + accu_shape = 1 + for i, (o, _) in enumerate(order): + if i == 0: + stride_ik[o] = 1 + else: + stride_ik[o] = accu_shape + accu_shape *= shape_ik[o] + return stride_ik + + if bits_map[mma_dtype] == 32: # x // 8 is to convert bits into uint8 + shape_ik = [8, 2, 4, 8 // 8, 2, NumK] + stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2]) + shape_i, shape_k = shape_ik[:3], shape_ik[3:] + stride_i, stride_k = stride_ik[:3], stride_ik[3:] + elif bits_map[mma_dtype] == 16: + shape_ik = [8, 2, 4, 16 // 8, 2, NumK] + stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2]) + shape_i, shape_k = shape_ik[:3], shape_ik[3:] + stride_i, stride_k = stride_ik[:3], stride_ik[3:] + elif bits_map[mma_dtype] == 8: + shape_i, shape_k = [64], [block_k // 8] + stride_i, stride_k = [block_k // 8], [1] + else: + raise NotImplementedError(f"Unknown mma type {mma_dtype}") + + shape = buffer.shape + + # repeat to buffer size in col major + rep_i = (shape[0] + 63) // 64 + rep_k = (shape[1] + prod(shape_k) - 1) // prod(shape_k) + rep_i_stride = prod(shape_i + shape_k) + shape_i.append(rep_i) + stride_i.append(rep_i_stride) + rep_k_stirde = prod(shape_i + shape_k) + shape_k.append(rep_k) + stride_k.append(rep_k_stirde) + + def transform(i: int, k: int) -> int: + nonlocal shape_i, shape_k, stride_i, stride_k + i_decomposed = decompose_col_major(i, shape_i) + k_decomposed = decompose_col_major(k, shape_k) + i_offset = sum(i_decomposed[k] * stride_i[k] for k in range(len(i_decomposed))) + k_offset = sum(k_decomposed[k] * stride_k[k] for k in range(len(k_decomposed))) + return i_offset + k_offset + + return T.Layout(shape, transform) + + +def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): + """Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem. 
+ ref: https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/sparse/_semi_structured_conversions.py#L5 + Args: + buffer: metadata buffer shape, for sm80 it should be a 16bit type + """ + + if mma_dtype in [T.float16, T.bfloat16] and buffer.dtype not in [T.uint16, T.int16]: + raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") + + if mma_dtype in ["float8_e4m3", "float8_e5m2", T.int8, T.uint8] and buffer.dtype not in [T.uint32, T.int32]: + raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") + + m, k = buffer.shape + group = 32 if buffer.dtype.bits == 16 else 16 + interweave = 4 if buffer.dtype.bits == 16 else 2 + + def ColumnMajorInterleaved(i: int, j: int) -> int: + i = i // group * group + (i % 8) * interweave + (i % group) // 8 + topright = (1 - (i % 2)) & (j % 2) + bottomleft = (i % 2) & (1 - (j % 2)) + i += topright - bottomleft + j -= topright - bottomleft + offset = (j // 2) * m * 2 + i * 2 + (j % 2) + return offset // k, offset % k + + return T.Layout(buffer.shape, ColumnMajorInterleaved) + + +def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = T.float16, arch: str | None = None, **extra_args): + if arch is None: + arch = nvcc.get_target_compute_version() + + compute_version = nvcc.parse_compute_version(arch) + + if compute_version >= (9, 0): + return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args) + elif compute_version >= (8, 0): + return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype) + else: + raise NotImplementedError(f"Unsupported architecture: {arch}") diff --git a/tilelang/original/tilelang/layout/layout.py b/tilelang/original/tilelang/layout/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd39e8de748199a1d2de77d6fa6ae55f96b4438 --- /dev/null +++ b/tilelang/original/tilelang/layout/layout.py @@ -0,0 +1,147 @@ +"""Wrapping Layouts.""" + +# pylint: disable=invalid-name, unsupported-binary-operation +import tvm_ffi +from tvm.ir import Node, Range +from tvm.tir import IterVar, Var, PrimExpr, IndexMap +from tilelang import _ffi_api + + +# Register the Layout class as a TVM object under the name "tl.Layout" +@tvm_ffi.register_object("tl.Layout") +class Layout(Node): + def __init__(self, shape, forward_fn): + """ + Initialize a Layout object. + + Parameters + ---------- + shape : list of int + The shape of the layout, defining the number of elements along each dimension. + forward_fn : function + A function that maps index variables to their computed forward index. 
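+            Example (illustrative): ``Layout((8, 16), lambda i, j: i * 16 + j)``
+            describes a plain row-major mapping over an 8x16 tile.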
+ """ + forward_vars = [] # List to store IterVars corresponding to each shape dimension + + # Create an IterVar for each dimension in the shape + for idx, size in enumerate(shape): + # Define an IterVar over the range [0, size) with an associated variable name + iv = IterVar(Range(0, size), Var(f"i{idx}", "int32"), 0) + forward_vars.append(iv) + + # Extract the variable references from the IterVars + vars = [iv.var for iv in forward_vars] + + # Compute the forward index using the provided forward function + forward_index = forward_fn(*vars) + + # Ensure forward_index is a list (to handle cases where a single expression is returned) + if isinstance(forward_index, PrimExpr): + forward_index = [forward_index] + + # Call the FFI constructor to create the Layout object in C++ backend + self.__init_handle_by_constructor__(_ffi_api.Layout, forward_vars, forward_index) + + @property + def index(self): + """ + Property to retrieve the forward index of the layout. + + Returns + ------- + PrimExpr or List[PrimExpr] + The computed forward index expression(s). + """ + return _ffi_api.Layout_index(self) + + def get_input_shape(self): + """ + Get the input shape of the layout. + + Returns + ------- + List[int] + The shape of the input layout. + """ + return _ffi_api.Layout_input_shape(self) + + def get_output_shape(self): + """ + Get the output shape of the layout. + + Returns + ------- + List[int] + The shape of the output layout. + """ + return _ffi_api.Layout_output_shape(self) + + def get_forward_vars(self): + """ + Retrieve the iteration variables associated with the layout. + + Returns + ------- + List[IterVar] + A list of iteration variables that define the layout transformation. + """ + return _ffi_api.Layout_forward_vars(self) + + def get_forward_index(self): + return self.index + + def map_forward_index(self, indices: list[PrimExpr]) -> PrimExpr: + """ + Compute the forward index mapping for a given set of input indices. + + Parameters + ---------- + indices : list of PrimExpr + The input indices to be mapped to their corresponding output indices. + + Returns + ------- + PrimExpr + The mapped index expression for the provided input indices. + """ + # Retrieve the iteration variables used in the layout transformation + forward_vars = self.get_forward_vars() + + # Retrieve the computed forward index expressions + forward_indexes = self.index + + # Construct an IndexMap to map the input indices to the computed output indices + index_map = IndexMap( + initial_indices=forward_vars, # The original iteration variables + final_indices=forward_indexes, # The computed forward indices + inverse_index_map=None, # No inverse mapping provided at this stage + ) + + # Map the provided indices using the constructed index mapping + return index_map.map_indices(indices) + + def inverse(self) -> "Layout": + """ + Compute the inverse of the current layout transformation. + + Returns + ------- + Layout + A new Layout object representing the inverse transformation. + """ + return _ffi_api.Layout_inverse(self) + + def is_equal(self, other: "Layout") -> bool: + """ + Check if the current layout is equal to another layout. + + Parameters + ---------- + other : Layout + The layout to compare with. 
+ """ + return _ffi_api.Layout_is_equal(self, other) + + def __repr__(self): + return self._DebugOutput() + # return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>" diff --git a/tilelang/original/tilelang/layout/swizzle.py b/tilelang/original/tilelang/layout/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..e083d756db52aa130c78a1a31959b1914b3c71c4 --- /dev/null +++ b/tilelang/original/tilelang/layout/swizzle.py @@ -0,0 +1,204 @@ +"""Wrapping Layouts.""" + +# pylint: disable=invalid-name, unsupported-binary-operation +from __future__ import annotations + +import tvm +from tvm.tir import Buffer, BufferLoad, BufferRegion +from tilelang import _ffi_api + + +def _get_buffer_info(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[Buffer, list[int], str]: + """ + Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion. + + Args: + buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion + + Returns: + tuple: (buffer, shape, dtype) + """ + if isinstance(buffer_or_load_or_region, Buffer): + return buffer_or_load_or_region, buffer_or_load_or_region.shape, buffer_or_load_or_region.dtype + elif isinstance(buffer_or_load_or_region, (BufferLoad, BufferRegion)): + buf = buffer_or_load_or_region.buffer + return buf, buf.shape, buf.dtype + else: + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + + +def _get_stride_continuous(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]: + """ + Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion. + + Args: + buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion + + Returns: + tuple: (stride, continuous) as integers + """ + _, shape, _ = _get_buffer_info(buffer_or_load_or_region) + stride = int(shape[-2]) + continuous = int(shape[-1]) + return stride, continuous + + +def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> int: + """ + Get element size in bits from Buffer, BufferLoad, or BufferRegion. + + Args: + buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion + + Returns: + int: Element size in bits + """ + _, _, dtype = _get_buffer_info(buffer_or_load_or_region) + return int(tvm.DataType(dtype).bits) + + +# Use a stable swizzled layout to ensure consistent memory access patterns. +# Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. 
+def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, k_major: bool = True, allow_pad: bool = True): + stride, continuous = _get_stride_continuous(buffer) + element_size = _get_element_size(buffer) + return _ffi_api.make_swizzled_layout( + stride, + continuous, + element_size, + k_major, + allow_pad, + ) + + +# for Volta Intrinsics +def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, is_a: bool = True, k_inner: bool = True): + stride, continuous = _get_stride_continuous(buffer) + return _ffi_api.make_volta_swizzled_layout( + stride, + continuous, + is_a, + k_inner, + ) + + +# for WGMMA Intrinsics +def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): + stride, continuous = _get_stride_continuous(buffer) + element_size = _get_element_size(buffer) + if continuity is None: + continuity = continuous + return _ffi_api.make_wgmma_swizzled_layout( + stride, + continuous, + continuity, + element_size, + k_major, + ) + + +# for TCGEN05MMA Intrinsics +def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): + stride, continuous = _get_stride_continuous(buffer) + element_size = _get_element_size(buffer) + if continuity is None: + continuity = continuous + return _ffi_api.make_tcgen05mma_swizzled_layout( + stride, + continuous, + continuity, + element_size, + k_major, + ) + + +# swizzle 128B +# args: buffer or (stride, continuous, element_size) +def make_full_bank_swizzled_layout(*args): + """ + Args: + args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size) + Examples: + make_full_bank_swizzled_layout(buffer) + make_full_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + stride, continuous = _get_stride_continuous(args[0]) + element_size = _get_element_size(args[0]) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_full_bank_swizzled_layout( + stride, + continuous, + element_size, + ) + + +# swizzle 64B +# args: buffer or (stride, continuous, element_size) +def make_half_bank_swizzled_layout(*args): + """ + Args: + args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size) + Examples: + make_half_bank_swizzled_layout(buffer) + make_half_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + stride, continuous = _get_stride_continuous(args[0]) + element_size = _get_element_size(args[0]) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_half_bank_swizzled_layout( + stride, + continuous, + element_size, + ) + + +# swizzle 32B +# args: buffer or (stride, continuous, element_size) +def make_quarter_bank_swizzled_layout(*args): + """ + Args: + args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size) + Examples: + make_quarter_bank_swizzled_layout(buffer) + make_quarter_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + stride, continuous = _get_stride_continuous(args[0]) + element_size = _get_element_size(args[0]) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_quarter_bank_swizzled_layout( + stride, + continuous, + element_size, + ) + + +def make_linear_layout(*args): + """ + Args: + args: 
buffer/BufferLoad/BufferRegion or (stride, continuous) + Examples: + make_linear_layout(buffer) + make_linear_layout(stride, continuous) + """ + if len(args) == 1: + stride, continuous = _get_stride_continuous(args[0]) + elif len(args) == 2: + stride, continuous = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_linear_layout( + stride, + continuous, + ) diff --git a/tilelang/original/tilelang/libinfo.py b/tilelang/original/tilelang/libinfo.py new file mode 100644 index 0000000000000000000000000000000000000000..d82986b7534edc0d4cd0a92864e6da7d85818947 --- /dev/null +++ b/tilelang/original/tilelang/libinfo.py @@ -0,0 +1,35 @@ +import sys +import os + +from .env import TL_LIBS + + +def find_lib_path(name: str, py_ext=False): + """Find tile lang library + + Parameters + ---------- + name : str + The name of the library + + optional: boolean + Whether the library is required + """ + if py_ext: + lib_name = f"{name}.abi3.so" + elif sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): + lib_name = f"lib{name}.so" + elif sys.platform.startswith("win32"): + lib_name = f"{name}.dll" + elif sys.platform.startswith("darwin"): + lib_name = f"lib{name}.dylib" + else: + lib_name = f"lib{name}.so" + + for lib_root in TL_LIBS: + lib_dll_path = os.path.join(lib_root, lib_name) + if os.path.exists(lib_dll_path) and os.path.isfile(lib_dll_path): + return lib_dll_path + else: + message = f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + "\n".join(TL_LIBS) + raise RuntimeError(message) diff --git a/tilelang/original/tilelang/math/__init__.py b/tilelang/original/tilelang/math/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4a5bcabe3f3ba9769dad09bb7d1226bcc87b29 --- /dev/null +++ b/tilelang/original/tilelang/math/__init__.py @@ -0,0 +1,6 @@ +def next_power_of_2(x: int) -> int: + return 1 << (x - 1).bit_length() + + +def cdiv(a: int, b: int) -> int: + return (a + b - 1) // b diff --git a/tilelang/original/tilelang/profiler/__init__.py b/tilelang/original/tilelang/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94d350153caffa38e00523e2b49c1224c9a4afed --- /dev/null +++ b/tilelang/original/tilelang/profiler/__init__.py @@ -0,0 +1,277 @@ +"""The profiler and convert to torch utils""" + +from __future__ import annotations +from typing import Callable, Any, Literal +from functools import partial +import torch +from contextlib import suppress +from dataclasses import dataclass +import tvm +from tilelang.utils.tensor import ( + get_tensor_supply, + TensorSupplyType, + torch_assert_close, + is_float8_dtype, +) +from tilelang.engine.param import KernelParam +from tilelang.jit.adapter import BaseKernelAdapter +from tilelang.profiler.bench import do_bench + + +@dataclass +class Profiler: + """A profiler class for benchmarking and validating kernel implementations. + + Attributes: + params: List of kernel parameters defining the input/output specifications + result_idx: Indices indicating which parameters are output tensors + supply_type: Type of tensor supply to use (e.g., random, zeros, etc.) 
+ adapter: Optional kernel adapter for interfacing with different backends + """ + + params: list[KernelParam] + result_idx: list[int] + supply_type: TensorSupplyType + adapter: BaseKernelAdapter | None = None + + def __post_init__(self): + """Initialize tensor supply after dataclass initialization""" + self.result_idx = self._legalize_result_idx(self.result_idx) + self.supply = get_tensor_supply(self.supply_type) + + def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]: + params = self.params + # result_idx is a list of indices of the output tensors + if result_idx is None: + result_idx = [] + elif isinstance(result_idx, int): + if result_idx > len(params) or result_idx < -len(params): + raise ValueError(f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") + if result_idx < 0: + result_idx = len(params) + result_idx + result_idx = [result_idx] + elif not isinstance(result_idx, list): + raise ValueError("result_idx should be a list of integers") + + return result_idx + + def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler: + self.adapter = adapter + return self + + def _get_inputs(self, with_output=False): + ins = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + ins.append(self.supply(self.params[i])) + return ins + + def _get_params(self, with_output=False): + params = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + params.append(self.params[i]) + return params + + def assert_allclose( + self, + reference_program: Callable, + input_tensors: list[torch.Tensor] | None = None, + atol: float = 1e-2, + rtol: float = 1e-2, + max_mismatched_ratio=0.01, + ): + """Validates kernel output against a reference implementation. + + Args: + reference_program: Reference implementation to compare against + input_tensors: Optional pre-generated input tensors + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + max_mismatched_ratio: Maximum allowed ratio of mismatched elements + """ + ins = self._get_inputs() if input_tensors is None else input_tensors + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + elif isinstance(lib_outs, tuple): + lib_outs = list(lib_outs) + elif lib_outs is None: + lib_outs = [] + + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + elif isinstance(ref_outs, tuple): + ref_outs = list(ref_outs) + elif ref_outs is None: + ref_outs = [] + + ref_tensors = ins + ref_outs + lib_tensors = ins + lib_outs + + assert len(lib_tensors) == len(ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" 
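+        # Inputs are compared alongside outputs, so in-place mutations made by
+        # either the kernel or the reference implementation are validated as well.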
+ # torch.set_printoptions(edgeitems=torch.inf) + for lhs, rhs in zip(lib_tensors, ref_tensors): + # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) + # total_elements = lhs.numel() + # num_not_close = (~close_mask).sum().item() + # percentage_not_close = (num_not_close / total_elements) * 100 + # print(f"{percentage_not_close:.2f}% of the elements are not close.") + # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") + if lhs is not None and rhs is not None: + # in case of numsplit template, the ref output may be None + # which means the value is invalid, so we skip the comparison + torch_assert_close( + lhs if not is_float8_dtype(lhs.dtype) else lhs.to(torch.float32), + rhs if not is_float8_dtype(rhs.dtype) else rhs.to(torch.float32), + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + base_name="tilelang", + ref_name="ref", + ) + + def manual_assert_close( + self, + reference_program: Callable, + input_tensors: list[torch.Tensor] | None = None, + manual_check_prog: Callable = None, + ): + """Validates kernel output against a reference implementation. + + Args: + reference_program: Reference implementation to compare against + input_tensors: Optional pre-generated input tensors + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + max_mismatched_ratio: Maximum allowed ratio of mismatched elements + """ + ins = self._get_inputs() if input_tensors is None else input_tensors + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + elif ref_outs is None: + ref_outs = [] + assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} not equals to {len(ref_outs)=} !" + torch.set_printoptions(edgeitems=torch.inf) + manual_check_prog(lib_outs, ref_outs) + + def assert_consistent(self, repeat=10): + """Checks for kernel consistency across multiple runs. + + Args: + repeat: Number of times to repeat the consistency check + """ + # Used to check no race condition inside the kernel + ins = self._get_inputs() + ref_outs = self.func(*ins) + + for _ in range(repeat): + lib_outs = self.func(*ins) + for lhs, rhs in zip(lib_outs, ref_outs): + assert torch.allclose(lhs, rhs), [ + "result is not consistent", + lhs, + rhs, + ] + + def run_once(self, func: Callable | None = None): + ins = self._get_inputs() + if not func: + func = self.__call__ + return func(*ins) + + def determine_profiler(self, func: Callable | None = None): + """Determines which profiler backend to use based on function type. + + Args: + func: Function to be profiled + profiler: Explicitly specified profiler type or "auto" for automatic detection + + Returns: + str: The determined profiler type ("torch" or "tvm") + """ + if isinstance(func, tvm.runtime.Module): + return "tvm" + else: + return "torch" + + def do_bench( + self, + func: Callable | None = None, + warmup: int = 25, + rep: int = 100, + n_warmup: int = 1, + n_repeat: int = 1, + input_tensors: list[torch.Tensor] = None, + backend: Literal["event", "cupti"] = "event", + quantiles: list[float] | None = None, + return_mode: Literal["min", "max", "mean", "median"] = "mean", + ) -> float: + """Benchmarks the execution time of a given function. 
+ + Args: + func: Function to benchmark (uses adapter if None) + warmup: Warmup time in milliseconds + rep: Number of repetitions for timing + n_warmup: Number of warmup iterations + n_repeat: Number of timing iterations + profiler: Which profiling backend to use + input_tensors: Optional pre-generated input tensors + + Returns: + float: Average execution time in milliseconds + """ + profiler = self.determine_profiler(func) + if profiler == "torch": + if func is None: + assert self.adapter is not None, "benchmarking function should be provided" + func = self.adapter + ins = self._get_inputs() if input_tensors is None else input_tensors + bench_func = partial(func, *ins) + return do_bench( + bench_func, + warmup=warmup, + rep=rep, + _n_warmup=n_warmup, + _n_repeat=n_repeat, + quantiles=quantiles, + backend=backend, + return_mode=return_mode, + ) + elif profiler == "tvm": + assert func is not None, "func should not be None" + assert isinstance(func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + + ins = self._get_inputs(with_output=True) if input_tensors is None else input_tensors + target = "cuda" + + with suppress(Exception): + target = self.mod.imported_modules[0].type_key + + assert target in ["cuda", "hip"], f"Unknown target: {target}" + + device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) + time_evaluator = self.mod.time_evaluator(self.mod.entry_name, device, number=rep, repeat=n_repeat) + # Transform Latency to ms + return time_evaluator(*ins).mean * 1e3 + else: + raise ValueError(f"Unknown profiler: {profiler}") + + @property + def func(self): + assert self.adapter is not None, "adapter should be provided" + return self.adapter + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.func(*args, **kwds) diff --git a/tilelang/original/tilelang/profiler/bench.py b/tilelang/original/tilelang/profiler/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcb5043debc093a1449992ac5592d03b8a768e9 --- /dev/null +++ b/tilelang/original/tilelang/profiler/bench.py @@ -0,0 +1,204 @@ +"""Profiler and benchmarking utilities for PyTorch functions.""" + +from __future__ import annotations + +import os +import sys +from typing import Callable, Literal + +import torch + + +class suppress_stdout_stderr: + """Context manager to suppress stdout and stderr output. 
+ + Source: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/testing/bench.py + """ + + def __enter__(self): + # Open null device files + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + # Save original file descriptors + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + # Save original stdout/stderr objects + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + # Redirect file descriptors and streams to null device + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + + return self + + def __exit__(self, *_): + # Restore original stdout/stderr objects + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + # Restore original file descriptors + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + # Close duplicated file descriptors + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + # Close null device files + self.outnull_file.close() + self.errnull_file.close() + + +IS_CUDA = torch.cuda.is_available() +device = "cuda:0" if IS_CUDA else "mps:0" +Event = torch.cuda.Event if IS_CUDA else torch.mps.Event + + +def do_bench( + fn: Callable, + warmup: float = 25, + rep: float = 100, + _n_warmup: int = 0, + _n_repeat: int = 0, + quantiles: list[float] | None = None, + fast_flush: bool = True, + backend: Literal["event", "cupti"] = "event", + return_mode: Literal["min", "max", "mean", "median"] = "mean", +) -> float | list[float]: + """Benchmark the runtime of a PyTorch function with L2 cache management. 
+ + This function provides accurate GPU kernel timing by: + - Clearing L2 cache between runs for consistent measurements + - Auto-calculating warmup and repeat counts based on kernel runtime + - Supporting multiple profiling backends (CUDA events or CUPTI) + - Offering flexible result aggregation (mean/median/min/max/quantiles) + + Args: + fn: Function to benchmark + warmup: Target warmup time in milliseconds (default: 25) + rep: Target total benchmark time in milliseconds (default: 100) + _n_warmup: Manual override for warmup iterations (default: 0 = auto) + _n_repeat: Manual override for benchmark iterations (default: 0 = auto) + quantiles: Performance percentiles to compute (e.g., [0.5, 0.95]) + fast_flush: Use faster L2 cache flush with int32 vs int8 (default: True) + backend: Profiler backend - "event" (CUDA events) or "cupti" (default: "event") + return_mode: Result aggregation method - "mean", "median", "min", or "max" + + Returns: + Runtime in milliseconds (float) or list of quantile values if quantiles specified + """ + assert return_mode in ["min", "max", "mean", "median"], f"Invalid return_mode: {return_mode}" + + # Initial function call and synchronization + fn() + torch.cuda.synchronize() + + # Create L2 cache flush buffer (256 MB) + # Fast flush uses int32 (4 bytes), regular uses int8 (1 byte) + cache_size = int(256e6 // 4) if fast_flush else int(256e6) + cache_dtype = torch.int if fast_flush else torch.int8 + cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda") + + # Estimate kernel runtime with 5 iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + start_event.synchronize() + end_event.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # Calculate warmup and repeat counts (minimum 1 iteration each) + n_warmup = _n_warmup if _n_warmup > 0 else max(1, int(warmup / estimate_ms)) + n_repeat = _n_repeat if _n_repeat > 0 else max(1, int(rep / estimate_ms)) + + # Warmup phase + for _ in range(n_warmup): + fn() + + # Benchmarking phase + if backend == "event": + return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode) + elif backend == "cupti": + return _bench_with_cupti(fn, cache, n_repeat) + else: + raise ValueError(f"Unknown profiler backend: {backend}") + + +def _bench_with_cuda_events( + fn: Callable, + cache: torch.Tensor, + n_repeat: int, + quantiles: list[float] | None, + return_mode: str, +) -> float | list[float]: + """Benchmark using CUDA events for timing.""" + # Create timing events + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + + # Run benchmark iterations + for i in range(n_repeat): + cache.zero_() # Clear L2 cache + start_events[i].record() + fn() + end_events[i].record() + + # Synchronize and collect timings + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_events, end_events)], + dtype=torch.float, + ) + + # Return quantiles if requested + if quantiles is not None: + quantile_values = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + return quantile_values[0] if len(quantile_values) == 1 else quantile_values + + # Return aggregated result + return getattr(torch, return_mode)(times).item() + + +def _bench_with_cupti( + fn: Callable, + cache: torch.Tensor, + n_repeat: int, +) -> 
float: + """Benchmark using CUPTI profiler for detailed kernel timing.""" + with suppress_stdout_stderr(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + schedule=schedule, + ) + + with profiler: + for _ in range(2): + for _ in range(n_repeat): + cache.zero_() + fn() + profiler.step() + + # Calculate average kernel time, excluding cache-clearing overhead + total_cuda_time = 0.0 + excluded_time = 0.0 + excluded_kernels = "at::native::vectorized_elementwise" + + for event in profiler.key_averages(): + total_cuda_time += event.self_device_time_total + if excluded_kernels in event.key: + excluded_time += event.self_device_time_total + + kernel_time_us = (total_cuda_time - excluded_time) / n_repeat + return kernel_time_us * 1e-3 # Convert microseconds to milliseconds diff --git a/tilelang/original/tilelang/quantize/__init__.py b/tilelang/original/tilelang/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bb8daa51eb8b3978a2efdb696c0ea1bb284ac7 --- /dev/null +++ b/tilelang/original/tilelang/quantize/__init__.py @@ -0,0 +1,18 @@ +from .quantization import ( + _tir_packed_int_to_int_convert, # noqa: F401 + _tir_packed_to_signed_convert, # noqa: F401 + _tir_packed_to_unsigned_convert, # noqa: F401 + _tir_packed_to_fp4_to_f16, # noqa: F401 + _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 + _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 + _tir_u8_to_f4_to_bf16, # noqa: F401 +) + +from .utils import ( + gen_quant4, # noqa: F401 + general_compress, # noqa: F401 + interleave_weight, # noqa: F401 +) + +from .lop3 import get_lop3_intrin_group # noqa: F401 +from .mxfp import get_mxfp_intrin_group # noqa: F401 diff --git a/tilelang/original/tilelang/quantize/lop3.py b/tilelang/original/tilelang/quantize/lop3.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1f457d1ed2b692d946e1cbf1de476228fb32df --- /dev/null +++ b/tilelang/original/tilelang/quantize/lop3.py @@ -0,0 +1,1199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Literal +from tilelang import language as T + +decode_i4_to_f16 = """ +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} +""" + +decode_i4_to_f16_scale = """ +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); +} + +""" + +decode_i4_to_f16_scale_offset = """ +template +__device__ void decode_i4b_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. 
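+    // How the fast decode works: a single lop3 computes ((i4s >> shift) & BOTTOM_MASK) | FP16_TOP_MAGIC_NUM,
+    // so each 4-bit field q becomes the fp16 bit pattern of (1024 + q) in both halves of h[i]
+    // (0x6400 is 1024.0 in fp16). Subtracting MEDIAN_NUM (1024.0, or 1032.0 in the signed variant)
+    // recovers q, and the fma applies the per-group scale. The split fma loops below use scale_l for
+    // the first half of the outputs and scale_r (read at `offset`) for the second half, so one call
+    // can straddle two quantization groups.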
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } + #pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4s_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_offset(_i4s, B_local_decode, N, scale, offset); +} + +template +__device__ void decode_i4u_to_f16_scale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_offset(_i4u, B_local_decode, N, scale, offset); +} + +""" + +decode_i4_to_f16_scale_zeros_original = """ +template +__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i4_to_f16_scale_zeros_original_offset = """ +template +__device__ void decode_i4b_to_f16_zeros_original_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T3 const zeros_l = *zeros; + T3 const zeros_r = *(zeros + offset); + uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l); + uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_zeros_original_offset(_i4u, B_local_decode, N, scale, zeros, offset); +} +""" + +decode_i4_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. 
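+    // "rescale" zeros mode: packed_zeros sets the fp16 sign bit (0x8000) on both halves, so the
+    // fma below computes q * scale - zero_point in a single instruction, i.e.
+    // target = dequantize_weight * scale - zero_point.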
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} + +""" + +decode_i4_to_f16_scale_zeros_rescale_offset = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T3 const zeros_l = *zeros; + T3 const zeros_r = *(zeros + offset); + uint const packed_zeros_l = 0x80008000 | __pack_half2(zeros_l, zeros_l); + uint const packed_zeros_r = 0x80008000 | __pack_half2(zeros_r, zeros_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(packed_zeros_l)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(packed_zeros_r)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale_offset(_i4u, B_local_decode, N, scale, zeros, offset); +} + +""" + +decode_i4_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + int16_t const zero_r = *((int16_t*)zeros); + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. 
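+    // "quantized" zeros mode: the zero points are stored in the same low-bit format as the weights.
+    // 0xe400 is -1024.0 in fp16, so median_num packs -(1024 + zero) into both halves; the add.f16x2
+    // below turns (1024 + q) into q - zero before the scale is applied.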
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i4_to_f16_scale_zeros_quantized_offset = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T1 *qzeros = nullptr, const int scale_offset = 0, const int qzeros_offset = 0, const int group_offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + + T3 const scale_l = *scale; + T3 const scale_r = *(scale + scale_offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + + const int num_elems_per_storage_dtype = sizeof(T1) * 8 / 4; + + T1 const qzeros_l = *qzeros; + T1 const qzeros_r = *(qzeros + qzeros_offset); + int16_t const zero_l = (qzeros_l >> (group_offset * 4) & 0xf); + int16_t const zero_r = (qzeros_r >> (group_offset * 4) & 0xf); + + uint median_num_l = ((0xe400 | zero_l) << 16) | (0xe400 | zero_l); + uint median_num_r = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + } + #pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_l)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_r)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized_offset(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, storage_dtype *qzeros = nullptr, const int scale_offset = 0, const int zero_offset = 0, const int group_offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized_offset(_i4u, B_local_decode, N, scale, qzeros, scale_offset, zero_offset, group_offset); +} +""" + +decode_i2_to_f16 = """ +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} +""" + +decode_i2_to_f16_scale = """ +template +__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); +} +""" + +decode_i2_to_f16_scale_zeros_original_offset = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_original_offset(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. 
+    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
+    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
+    // otherwise the pointer of _i2s should be moved to
+    int i2s = (i2s_i16 & 0x00ff);
+    i2s |= ((i2s_i16 & 0xff00) << 8);
+
+    T3 const zeros_l = *zeros;
+    T3 const zeros_r = *(zeros + offset);
+    uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l);
+    uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r);
+
+    T3 const scale_l = *scale;
+    T3 const scale_r = *(scale + offset);
+    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
+    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
+
+#pragma unroll
+    for (int i = 0; i < (N / 2); i++)
+    {
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+                     : "=r"(h[i])
+                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+        asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+    }
+#pragma unroll
+    for (int i = 0; i < (N / 4); i++)
+    {
+        asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l));
+        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
+    }
+#pragma unroll
+    for (int i = (N / 4); i < (N / 2); i++)
+    {
+        asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r));
+        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
+    }
+}
+
+template
+__device__ void decode_i2u_to_f16_scale_zeros_original_offset(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int offset = 0, const int N = 8)
+{
+    // Dispatch to the matching *_offset implementation (the non-offset variant takes no `offset` argument).
+    decode_i2b_to_f16_scale_zeros_original_offset(_i2u, B_local_decode, scale, zeros, offset, N);
+}
+"""
+
+decode_i2_to_f16_scale_zeros_original = """
+template
+__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8)
+{
+    uint *h = reinterpret_cast(B_local_decode);
+
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint BOTTOM_MASK = 0x00030003;
+    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
+    int16_t const i2s_i16 = *reinterpret_cast(_i2s);
+    // decode 2 elems at one time.
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_original(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_rescale(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + int16_t const zero_r = *((int16_t*)zeros); + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + + // decode 2 elems at one time. 
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1_to_f16 = """ +/* +Kind 0: original +Kind 1: rescale +Kind 2: quantized +# documents for zeros_mode: +# original: target = (dequantize_weight - zero_point) * scale +# rescale: target = dequantize_weight * scale - zero_point +# quantized: target = (dequantize_weight - dequantize_zeros) * scale +# Notice: only support "original" and "rescale" now +zeros_mode: Literal["original", "rescale", "quantized"] = "original" +*/ +template +__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. 
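+    // Signed path: each decoded bit b in {0, 1} is remapped to 2*b - 1 in {-1, +1} by doubling h[i]
+    // (add.f16x2 of h[i] with itself) and then adding TRANSFORM_SUBTRACT (0xbc00 is -1.0 in fp16).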
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (isSigned) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + } + if constexpr (withZeros && ZerosKind == 0) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + if constexpr (withZeros && ZerosKind == 1) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } + } +} + +template +__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1s, B_local_decode, N); +} + +template +__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N); +} +""" + +decode_i1_to_f16_scale = """ +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. 
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +""" + +decode_i1_to_f16_scale_zeros_original = """ +template +__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. 
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1s_to_i8s = """template +__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int i8s[4]; + // vector load + *reinterpret_cast(i8s) = *reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff; // for signed int 2x - 1 + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vadd4(i8s[i], i8s[i]); + i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT); + } + *reinterpret_cast(_i8s) = *reinterpret_cast(i8s); +} + +template +__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int *i8s = reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint MEDIAN_NUM = 0x00000000; + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} + +""" + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. 
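+    // Unlike the f16 paths above, no floating-point magic number is needed here: lop3 drops each
+    // 2-bit field straight into one byte lane of an int32, and __vsub4 subtracts MEDIAN_NUM
+    // (2 per byte) to map the unsigned range 0..3 onto the signed range -2..1.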
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i4s_to_i8s = """template +__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = 0x07070707; +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM); + } +} + +template +__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i2s_to_i4s = r""" +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. 
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. + // 0x03 0x03 0x03 0x03 + // i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} +""" + + +def get_lop3_intrin_group( + out_dtype: Literal[T.float16, T.int8, T.int4], + source_format: Literal[T.int, T.uint] = T.uint, + source_bit: int = 4, + storage_dtype: Literal[T.int32, T.int8] = T.int8, + with_scaling: bool = False, + with_zeros: bool = False, + zeros_mode: Literal["original", "rescale", "quantized"] = "original", + storage_scope: str = "local", +) -> dict[str, str]: + """ + This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. + LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of + intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. + + Parameters + ---------- + in_dtype : Literal[T.int8] + The data type of the input. It should be "int8". + + out_dtype : Literal[T.float16, T.int8, T.int4] + The data type of the output. It can be either "float16" or "int8" or "int4". + + storage_nbit : int, optional + The number of bits used for storage. By default, it is 4. + + with_scale : bool, optional + A boolean parameter that indicates whether scaling should be applied. By default, it is False. + + with_zeros : bool, optional + A boolean parameter that indicates whether zeros should be used. By default, it is False. + + zeros_mode : Literal["original", "rescale", "quantized"], optional + The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original". + + storage_scope : Literal["local", "warp"], optional + The scope of the storage. It can be either "local" or "warp". By default, it is "local". + + Returns + ------- + Dict[str, str] + A dictionary mapping the names of the intrinsics to their corresponding implementations. + """ + out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype) + assert out_dtype in [T.float16, T.int8, T.int4], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." + + dtype_mapping = {T.float16: "f16", T.int4: "i4", T.int8: "i8", T.int32: "i32"} + target_dtype = dtype_mapping[out_dtype] + + if source_format not in [T.int, T.uint]: + raise ValueError(f"Invalid source_format. 
Expected 'int' or 'uint', but got {source_format}, {type(source_format)}.") + if with_zeros and source_format == T.int: + raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") + + import_c_map = { + "i4_to_f16": decode_i4_to_f16, + "i2_to_f16": decode_i2_to_f16, + "i1_to_f16": decode_i1_to_f16, + "i4_to_f16_scale": decode_i4_to_f16_scale, + "i4_to_f16_scale_offset": decode_i4_to_f16_scale_offset, + "i2_to_f16_scale": decode_i2_to_f16_scale, + "i1_to_f16_scale": decode_i1_to_f16_scale, + "i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original, + "i4_to_f16_scale_zeros_original_offset": decode_i4_to_f16_scale_zeros_original_offset, + "i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original, + "i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original, + "i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_rescale_offset": decode_i4_to_f16_scale_zeros_rescale_offset, + "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, + "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, + "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized, + "i4_to_f16_scale_zeros_quantized_offset": decode_i4_to_f16_scale_zeros_quantized_offset, + "i1_to_i8": decode_i1s_to_i8s, + "i2_to_i8": decode_i2s_to_i8s, + "i4_to_i8": decode_i4s_to_i8s, + "i2_to_i4": decode_i2s_to_i4s, + } + key = f"i{source_bit}_to_{target_dtype}" + if with_scaling: + key += "_scale" + if with_zeros: + key += f"_zeros_{zeros_mode}" + + is_ladder_stage3 = (storage_scope == "warp") and with_scaling + if is_ladder_stage3: + key += "_offset" + + if out_dtype == T.float16: + d4f = "f16" + elif out_dtype == T.int8: + d4f = "i8s" + elif out_dtype == T.int4: + d4f = "i4s" + else: + raise ValueError(f"Unsupported target dtype: {target_dtype}") + source_symbol = "u" if source_format == T.uint else "s" + func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}" + if with_scaling: + func_name += "_scale" + if with_zeros: + func_name += f"_zeros_{zeros_mode}" + if is_ladder_stage3: + func_name += "_offset" + + return { + "func_name": func_name, + "c_source": import_c_map[key], + } diff --git a/tilelang/original/tilelang/quantize/mxfp.py b/tilelang/original/tilelang/quantize/mxfp.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7100a6298f9c89d153f18e405229bc9b79f64d --- /dev/null +++ b/tilelang/original/tilelang/quantize/mxfp.py @@ -0,0 +1,105 @@ +from typing import Literal +from tilelang import language as T + +# Implementation asm for fp4 to bf16, using twiddling +# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18 +decode_f4_to_bf16_twiddling = """ +// N should be the number of elements processed by one thread +template +__device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, const int N = 8) { + #pragma unroll + for (int i = 0; i < N; ++i) { + uint B_dequantize_local_vec[4]; + uint tmp, bias, d0, d1, d2, d3, d4, d5, d6; + asm volatile( + // To handle the endianness issue + "prmt.b32 %13, %4, 0, 0x0123;" + "mov.b32 %12, 0x7e807e80;" + "and.b32 %0, %13, 0b10000001110000001000000111000000;" + "mul.bf16x2 %0, %0, %12;" + "shl.b32 %1, %13, 3;" + "and.b32 %1, %1, 0b10000001110000001000000111000000;" + "mul.bf16x2 %1, %1, %12;" + "shl.b32 %2, %13, 6;" + "and.b32 
%2, %2, 0b10000001110000001000000111000000;" + "mul.bf16x2 %2, %2, %12;" + "shl.b32 %5, %13, 1;" + "and.b32 %6, %5, 0b10000000000000001000000000000000;" + "shr.b32 %7, %13, 3;" + "and.b32 %8, %7, 0b00000001100000000000000110000000;" + "or.b32 %9, %6, %8;" + "shr.b32 %10, %13, 7;" + "and.b32 %11, %10, 0b00000000010000000000000001000000;" + "or.b32 %3, %9, %11;" + "mul.bf16x2 %3, %3, %12;" + :"=r"(B_dequantize_local_vec[0]) + ,"=r"(B_dequantize_local_vec[1]) + ,"=r"(B_dequantize_local_vec[2]) + ,"=r"(B_dequantize_local_vec[3]) + :"r"(*(uint*)&B_local[i << 2]), "r"(d0), "r"(d1), "r"(d2), "r"(d3), "r"(d4), "r"(d5), "r"(d6), "r"(bias), "r"(tmp) + ); + for (int j = 0; j < 4; ++j) { + // Pay attention to the big-endianness issue + B_local_decode[(i << 3) + j] = reinterpret_cast(&B_dequantize_local_vec[j])[1]; + B_local_decode[(i << 3) + j + 4] = reinterpret_cast(&B_dequantize_local_vec[j])[0]; + } + } + // Check if the synchronization is needed +} +""" + + +def get_mxfp_intrin_group( + out_dtype: Literal[T.float16, T.bfloat16] = T.bfloat16, + source_format: Literal[T.int, T.uint] = T.uint, + source_bit: int = 4, + storage_dtype: Literal[T.int32, T.int8, T.uint8] = T.uint8, + use_twiddling: bool = False, +) -> dict[str, str]: + """ + Return metadata for an MXFP decoding intrinsic: function name and C source string. + + Validates the requested output dtype, source format, and storage dtype, then constructs + a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when + use_twiddling is True) to select the corresponding C source snippet and a matching + function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with + `_twiddling`). + + Parameters: + out_dtype: Target floating-point type for decoded values; either T.float16 or T.bfloat16. + source_format: Integer source representation; "int" or "uint". + source_bit: Bit width of the packed source format (e.g., 4). + storage_dtype: Underlying storage integer dtype (one of T.int32, T.int8, T.uint8). + use_twiddling: When True, select the twiddling variant of the decoding intrinsic. + + Returns: + A dict with: + - "func_name": the generated C function name string for the requested decode intrinsic. + - "c_source": the C source string for that intrinsic. + + Raises: + AssertionError: if out_dtype, source_format, or storage_dtype are not supported. + KeyError: if the constructed key does not match any available C source implementation. + """ + out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype) + assert out_dtype in [T.float16, T.bfloat16], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." + assert source_format in [T.int, T.uint], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." + assert storage_dtype in [T.int32, T.int8, T.uint8], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." 
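+    # Note: only the fp4 -> bf16 twiddling variant has a C implementation in import_c_map below,
+    # so any other out_dtype / source_bit / use_twiddling combination raises a KeyError on lookup.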
+ + dtype_map = {T.float16: "f16", T.bfloat16: "bf16"} + key = f"fp{source_bit}_to_{dtype_map[out_dtype]}" + if use_twiddling: + key += "_twiddling" + + import_c_map = { + "fp4_to_bf16_twiddling": decode_f4_to_bf16_twiddling, + } + + func_name = f"decode_fp{source_bit}_to_{dtype_map[out_dtype]}" + if use_twiddling: + func_name += "_twiddling" + + return { + "func_name": func_name, + "c_source": import_c_map[key], + } diff --git a/tilelang/original/tilelang/quantize/quantization.py b/tilelang/original/tilelang/quantize/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..74a545f258c419b463bf39734504699458b6162a --- /dev/null +++ b/tilelang/original/tilelang/quantize/quantization.py @@ -0,0 +1,294 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. +# pylint: disable=invalid-name,missing-function-docstring,unused-variable +"""TIR computation utilities for quantization.""" + +from tilelang import language as T +from tilelang import tvm as tvm +from tvm import tir + + +# fmt: off +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, + dtype: str): + """ + Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale. + + This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns + a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa. + + Behavior: + - Validates `nbit == 4`, `dtype == T.bfloat16`, and `val.dtype == T.uint8` (AssertionError if violated). + - Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`). + - Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0. + - Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent, + and clamps the result to the 8-bit exponent range (0..255). + - Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled-exponent, mantissa) and + returns it reinterpreted as `bfloat16`. + + Parameters: + - nbit: must be 4 (width of the packed field). + - val: uint8 expression containing packed fields. + - pos: index of the field within `val` (0-based); used to compute the bit shift. + - scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression). + - dtype: must be T.bfloat16. + + Returns: + - A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value. 
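+
+    Worked example (follows directly from the rules above): the field 0b0110 has sign 0,
+    exponent 0b11 and mantissa 0, so with scale == 0 it decodes to 2^(3 + 126 - 127) * 1.0 = 4.0.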
+ """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we use the max function to limit the exponential part to 8 bits + e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret(T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) + | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16)) + return val_bf16 + +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): + """ + Convert two float32 values to bfloat16 and pack them into a single uint32. + + The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even + by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are + packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits. + + Parameters: + v0 (tir.PrimExpr): First float32 value to convert and pack. + v1 (tir.PrimExpr): Second float32 value to convert and pack. + round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True). + + Returns: + tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits). + """ + mask = tir.const((1 << 16) - 1, T.uint32) + res = [] + for data in [v0, v1]: + u32_val = tir.reinterpret(T.uint32, data) + if round_to_even: + rounding_bias = ((u32_val >> tir.const(16, T.uint32)) + & tir.const(1, T.uint32)) + tir.const(0x7FFF, T.uint32) + u32_val += rounding_bias + res.append((u32_val >> tir.const(16, T.uint32)) & mask) + return res[0] | (res[1] << tir.const(16, T.uint32)) + + +def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): + mask = tir.const((1 << 16) - 1, T.uint32) + x0 = x & mask + x1 = (x >> 16) & mask + return (tir.reinterpret(T.float32, x << tir.const(16, T.uint32)) for x in [x0, x1]) + + +def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == T.uint32 + mask = tvm.tir.const((1 << nbit) - 1, T.uint32) + return tir.Cast(dtype, (val >> (pos * nbit).astype(T.uint32)) & mask) + + +def _tir_packed_uint_to_uint_to_float(storage_nbit: int): + storage_dtype = "uint" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, T.int32) + unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, 
T.int32)) >> tir.const(32 - nbit, T.int32)) + + return f_convert + + +def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == T.float32 + val_u32 = tir.reinterpret(T.uint32, val) + # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) + # e_f32 == 120 -> e_f4 = 1 + # e_f32 < 120 -> e_f4 = 0 + m_h = (val_u32 >> tir.const(22, T.uint32)) & tir.const(1, T.uint32) + e_f32 = (val_u32 >> tir.const(23, T.uint32)) & tir.const(255, T.uint32) + s = (val_u32 >> tir.const(31, T.uint32)) + e_f4 = tir.Select( + e_f32 > tir.const(120, T.uint32), + tir.Min(e_f32 - tir.const(120, T.uint32) + m_h, tir.const(7, T.uint32)), + tir.Select(e_f32 == tir.const(120, T.uint32), tir.const(1, T.uint32), + tir.const(0, T.uint32))) + return (s << tir.const(3, T.uint32)) | e_f4 + + +def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == T.float16 + val_u32 = tir.Cast(T.uint32, tir.reinterpret(T.uint16, val)) + m_h = (val_u32 >> tir.const(9, T.uint32)) & tir.const(1, T.uint32) + e_f16 = (val_u32 >> tir.const(10, T.uint32)) & tir.const(31, T.uint32) + s = (val_u32 >> tir.const(15, T.uint32)) + e_f4 = tir.Select( + e_f16 > tir.const(8, T.uint32), + tir.Min(e_f16 - tir.const(8, T.uint32) + m_h, tir.const(7, T.uint32)), + tir.Select(e_f16 == tir.const(8, T.uint32), tir.const(1, T.uint32), tir.const(0, T.uint32))) + return (s << tir.const(3, T.uint32)) | e_f4 + + +def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.float32 + assert val.dtype == T.uint32 + # e_f4 == 0 -> e_f32 = 0 + # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 + mask = tvm.tir.const((1 << nbit) - 1, T.uint32) + f4 = (val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & mask + s = f4 >> tir.const(3, T.uint32) + e_f4 = f4 & tir.const(7, T.uint32) + e_f32 = e_f4 | tir.const(120, T.uint32) + val_f32 = tir.reinterpret(T.float32, + (e_f32 | (s << tir.const(8, T.uint32))) << tir.const(23, T.uint32)) + return tir.Select(e_f4 == tir.const(0, T.uint32), tir.const(0, T.float32), val_f32) + + +def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.float16 + assert val.dtype == T.uint32 + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 + mask = tvm.tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = f4 & tir.const(7, T.uint16) + e_f16 = e_f4 | tir.const(8, T.uint16) + val_f16 = tir.reinterpret(T.float16, + ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16)).astype(T.uint16)) + return tir.Select(e_f4 == tir.const(0, T.uint16), tir.const(0, T.float16), val_f16) + +def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + f4 = ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(storage_dtype) + f4 = (val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask + s = f4 >> tir.const(3, storage_dtype) + e_f4 = f4 & tir.const(7, storage_dtype) + e_f16 = e_f4 | tir.const(8, storage_dtype) + val_f16 = tir.reinterpret(T.float16, + ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype(T.uint16)) + return tir.Select(e_f4 == 
tir.const(0, storage_dtype), tir.const(0, T.float16), val_f16) + + return f_convert + +def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == T.float16 + s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16) + e4 = val & tir.const(0x40, T.uint16) + prefix = tir.Select(e4 == tir.const(0, T.uint16), tir.const(0x2000, T.uint16), + tir.const(0x4000, T.uint16)) + e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | prefix + return tir.reinterpret(T.float16, s_f16 | e_f16) + + +def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == T.float16 + s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16) + e4 = val & tir.const(0x40, T.uint16) + e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | (e4 << tir.const(8, T.uint16)) | (e4 << tir.const(7, T.uint16)) + e_f16 = e_f16 ^ tir.const(0x2000, T.uint16) + return tir.reinterpret(T.float16, s_f16 | e_f16) + + +def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == T.float16 + return tir.reinterpret("float8_e5m2", val).astype(T.float16) + + +def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) + return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr, + dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype) + + return f_convert + + +def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, T.int32) + unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32)) + + return f_convert + + +# fmt: on diff --git a/tilelang/original/tilelang/quantize/utils.py b/tilelang/original/tilelang/quantize/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d092a0bab95b7c1a6e34fbbdff61ac76da340d2 --- /dev/null +++ b/tilelang/original/tilelang/quantize/utils.py @@ -0,0 +1,125 @@ +def gen_quant4(k, n, groupsize=-1): + import torch + import 
torch.nn as nn + + maxq = 2**4 + w = torch.randn((k, n), dtype=torch.half, device="cpu") + + original_w = w.clone() + + if groupsize == -1: + groupsize = k + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq) // 2 + + w = torch.clamp(w, 0, maxq) + + # Dequantize. + ref = (w - (maxq) // 2).half() * s + + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(k, n, bias=False) + linear.weight.data = ref.t() + + return original_w, linear, s, (w - (maxq) // 2) + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): + import torch + + if storage_dtype is None: + storage_dtype = torch.int8 + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == torch.float16: + lowprecision_weight = lowprecision_weight.to(torch.int8) + int8_weight = torch.zeros( + (*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte), + dtype=torch.int8, + device=lowprecision_weight.device, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << (source_bits * k)).to(torch.int8) + + return int8_weight.to(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype="float16"): + """Interleave the weight to the target data type. + + Args: + qweight (_type_): _description_ + nbits (int, optional): _description_. Defaults to 4. + target_dtype (str, optional): _description_. Defaults to "float16". 
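# Illustrative sanity check (hypothetical, not part of the repo) for the packing
# performed by general_compress above: with 4 source bits, two values share one
# byte and the first element lands in the low nibble (the shift grows with k),
# so [1, 2, 3, 4] packs into [0x21, 0x43].
import torch

w = torch.tensor([[1, 2, 3, 4]], dtype=torch.int8)
packed = general_compress(w, source_bits=4)   # shape (1, 2), dtype int8
assert packed.tolist() == [[0x21, 0x43]]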
+ + Returns: + _type_: _description_ + + Example: + qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda() + interleave_weight(qweight, 4, "float16") + """ + import torch + + assert target_dtype in ["float16", "int8"] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(torch.int32) + new_qweight = torch.zeros_like(qweight) + bits_stride = 8 if target_dtype == "int8" else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == "int8": + # special handling for 1b interleave + n16_weight = new_qweight & torch.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(torch.int8) + elif nbits == 2 and target_dtype == "float16": + n8_weight = new_qweight & torch.int32(0xFF0000FF) + n8_weight |= ((new_qweight & torch.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & torch.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(torch.int8) + elif nbits == 1 and target_dtype == "float16": + n8_weight = new_qweight & torch.int32(0xF000000F) + n8_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 8 + n8_weight |= ((new_qweight & torch.int32(0x00000F00)) >> 8) << 16 + n8_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24 + n8_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4 + n8_weight |= ((new_qweight & torch.int32(0x00F00000)) >> 20) << 12 + n8_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 20 + return n8_weight.view(torch.int8) + + return new_qweight.view(torch.int8) diff --git a/tilelang/original/tilelang/testing/__init__.py b/tilelang/original/tilelang/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..635fad365ce4d867f255531f01c0b9c2dfefe616 --- /dev/null +++ b/tilelang/original/tilelang/testing/__init__.py @@ -0,0 +1,121 @@ +import sys +import inspect +import pytest +import random +import torch +import numpy as np +from tilelang.contrib import nvcc +from tvm.testing.utils import requires_cuda, requires_package, requires_llvm, requires_metal, requires_rocm, _compose + +from tilelang.utils.tensor import torch_assert_close as torch_assert_close + +__all__ = [ + "requires_package", + "requires_cuda", + "requires_metal", + "requires_rocm", + "requires_llvm", + "main", + "requires_cuda_compute_version", +] + [f"requires_cuda_compute_version_{op}" for op in ("ge", "gt", "le", "lt", "eq")] + + +# pytest.main() wrapper to allow running single test file +def main(): + test_file = inspect.getsourcefile(sys._getframe(1)) + sys.exit(pytest.main([test_file] + sys.argv[1:])) + + +def set_random_seed(seed: int = 42) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def requires_cuda_compute_version(major_version, minor_version=0, mode="ge"): + """Mark a test as requiring at least a compute architecture + + Unit test marked with this decorator will run only if the CUDA 
+ compute architecture of the GPU is at least `(major_version, + minor_version)`. + + This also marks the test as requiring a cuda support. + + Parameters + ---------- + major_version: int + + The major version of the (major,minor) version tuple. + + minor_version: int + + The minor version of the (major,minor) version tuple. + + mode: str + + The mode of the comparison. + - "ge": greater than or equal to + - "gt": greater than + - "le": less than or equal to + - "lt": less than + """ + min_version = (major_version, minor_version) + try: + arch = nvcc.get_target_compute_version() + compute_version = nvcc.parse_compute_version(arch) + except ValueError: + # No GPU present. This test will be skipped from the + # requires_cuda() marks as well. + compute_version = (0, 0) + + min_version_str = ".".join(str(v) for v in min_version) + compute_version_str = ".".join(str(v) for v in compute_version) + + def compare(compute_version, min_version, mode) -> bool: + if mode == "ge": + return compute_version >= min_version + elif mode == "gt": + return compute_version > min_version + elif mode == "le": + return compute_version <= min_version + elif mode == "lt": + return compute_version < min_version + elif mode == "eq": + return compute_version == min_version + else: + raise ValueError(f"Invalid mode: {mode}") + + requires = [ + pytest.mark.skipif( + not compare(compute_version, min_version, mode), + reason=f"Requires CUDA compute {mode} {min_version_str}, but have {compute_version_str}", + ), + *requires_cuda.marks(), + ] + + def inner(func): + return _compose([func], requires) + + return inner + + +def requires_cuda_compute_version_ge(major_version, minor_version=0): + return requires_cuda_compute_version(major_version, minor_version, mode="ge") + + +def requires_cuda_compute_version_gt(major_version, minor_version=0): + return requires_cuda_compute_version(major_version, minor_version, mode="gt") + + +def requires_cuda_compute_version_eq(major_version, minor_version=0): + return requires_cuda_compute_version(major_version, minor_version, mode="eq") + + +def requires_cuda_compute_version_lt(major_version, minor_version=0): + return requires_cuda_compute_version(major_version, minor_version, mode="lt") + + +def requires_cuda_compute_version_le(major_version, minor_version=0): + return requires_cuda_compute_version(major_version, minor_version, mode="le") diff --git a/tilelang/original/tilelang/tileop/__init__.py b/tilelang/original/tilelang/tileop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7798a05104bc70bbe4eb62a3b2ac3958217ed7 --- /dev/null +++ b/tilelang/original/tilelang/tileop/__init__.py @@ -0,0 +1,3 @@ +from .base import GemmWarpPolicy # noqa: F401 +from .gemm import GemmPy # noqa: F401 +from .gemm_sp import GemmSPPy # noqa: F401 diff --git a/tilelang/original/tilelang/tileop/base.py b/tilelang/original/tilelang/tileop/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b51b3ac5d62bb4dc7a93fdae40b77cd2f5246e --- /dev/null +++ b/tilelang/original/tilelang/tileop/base.py @@ -0,0 +1,185 @@ +from __future__ import annotations +from enum import IntEnum + + +class GemmWarpPolicy(IntEnum): + """ + Enumeration for GEMM Warp Partitioning Policies. + """ + + Square = 0 # Balance warps evenly in a "square" aspect ratio. + FullRow = 1 # Assign all warps to rows. + FullCol = 2 # Assign all warps to columns. + + def is_square(self) -> bool: + """ + Check if the policy is a square partitioning. 
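# Illustrative usage (a hypothetical test module, assuming the package is
# importable as tilelang.testing): the version-gated marker defined in this
# file skips the test unless a CUDA GPU with compute capability >= 9.0
# (Hopper or newer) is present.
import tilelang.testing

@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_needs_hopper():
    ...  # body elided; any GPU-dependent check goes here

if __name__ == "__main__":
    tilelang.testing.main()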
+ + Returns: + bool: True if the policy is square, False otherwise. + """ + return self == GemmWarpPolicy.Square + + def is_full_row(self) -> bool: + """ + Check if the policy is a full row partitioning. + + Returns: + bool: True if the policy is full row, False otherwise. + """ + return self == GemmWarpPolicy.FullRow + + def is_full_col(self) -> bool: + """ + Check if the policy is a full column partitioning. + + Returns: + bool: True if the policy is full column, False otherwise. + """ + return self == GemmWarpPolicy.FullCol + + @staticmethod + def to_prime_factors(num): + """ + Compute the prime factorization of a given number. + + Args: + num (int): The number to factorize. + + Returns: + list: A list of prime factors of the number. + """ + factors = [] + i = 2 + # Find all prime factors up to the square root of the number. + while i * i <= num: + while num % i == 0: # Check divisibility by `i`. + factors.append(i) + num //= i + i += 1 + # If the remaining number is greater than 1, it's a prime factor. + if num > 1: + factors.append(num) + return factors + + def compute_warp_partition(self, M, N, num_warps): + """ + Compute the warp partition (m_warp, n_warp) based on the given policy. + + Args: + M (int): The number of rows in the GEMM workload. + N (int): The number of columns in the GEMM workload. + num_warps (int): The total number of warps available. + + Returns: + tuple: A tuple (m_warp, n_warp) representing the partitioning of warps. + + Raises: + ValueError: If the policy is invalid or the partitioning fails. + AssertionError: If M or N is not divisible by the required factor for FullRow or FullCol policies. + """ + m_warp = 1 # Initial warp count for rows. + n_warp = 1 # Initial warp count for columns. + + if self.is_full_row(): + # FullRow policy: Allocate all warps to rows. + m_warp = num_warps + n_warp = 1 + + # If M cannot be evenly divided by m_warp*16, try to split remaining warps to N + if M % (m_warp * 16) != 0: + # Calculate how many warps we can use for M + max_m_warps = M // 16 + m_warp = max_m_warps + # Use remaining warps for N + n_warp = num_warps // m_warp + if n_warp == 0: + n_warp = 1 + + elif self.is_full_col(): + # FullCol policy: Allocate all warps to columns. 
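# Worked example (assumptions: module importable as tilelang.tileop.base) of the
# FullRow fallback above: with M = 32 and 4 warps, only 32 // 16 = 2 warps fit
# on rows, so the surplus spills to columns and the partition becomes (2, 2).
from tilelang.tileop.base import GemmWarpPolicy

assert GemmWarpPolicy.FullRow.compute_warp_partition(32, 128, 4) == (2, 2)
# The Square policy searches every m * n == num_warps split for the most
# balanced one; for a 128x128 tile with 4 warps it also settles on (2, 2).
assert GemmWarpPolicy.Square.compute_warp_partition(128, 128, 4) == (2, 2)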
+ m_warp = 1 + n_warp = num_warps + + # If N cannot be evenly divided by n_warp*8, try to split remaining warps to M + if N % (n_warp * 8) != 0: + # Calculate how many warps we can use for N + max_n_warps = N // 8 + n_warp = max_n_warps + # Use remaining warps for M + m_warp = num_warps // n_warp + if m_warp == 0: + m_warp = 1 + + elif self.is_square(): + # First calculate the maximum possible warps for each dimension + max_m_warps = M // 16 # Each warp needs at least 16 elements in M + max_n_warps = N // 8 # Each warp needs at least 8 elements in N + + # Calculate the ideal ratio of M/N warps based on the matrix dimensions + ideal_ratio = 1.0 + if N > 0: + ideal_ratio = float(M) / N + + # Start with a balanced initial guess + m_warp = 1 + n_warp = 1 + + # Try to find the best balanced partition + best_m = 1 + best_n = 1 + best_balance = float("inf") + + # Try all possible combinations that satisfy the constraints + for m in range(1, min(max_m_warps, num_warps) + 1): + n = num_warps // m + if n > max_n_warps: + continue + if m * n != num_warps: + continue + + # Calculate how balanced this partition is + m_per_warp = float(M) / (m * 16) + n_per_warp = float(N) / (n * 8) + balance = abs(m_per_warp / n_per_warp - ideal_ratio) + + if balance < best_balance: + best_balance = balance + best_m = m + best_n = n + + m_warp = best_m + n_warp = best_n + + else: + # Raise an error for unknown policies. + raise ValueError(f"Unknown GemmWarpPolicy: {self}") + + return m_warp, n_warp + + @classmethod + def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy: + """ + Determine the warp policy based on the given warp partitioning. + + Args: + m_warp (int): Number of warps in the row dimension + n_warp (int): Number of warps in the column dimension + + Returns: + GemmWarpPolicy: The corresponding warp policy + + Examples: + >>> GemmWarpPolicy.from_block_row_cols(4, 1) # All warps in rows + GemmWarpPolicy.FullRow + >>> GemmWarpPolicy.from_block_row_cols(1, 4) # All warps in columns + GemmWarpPolicy.FullCol + >>> GemmWarpPolicy.from_block_row_cols(2, 2) # Balanced distribution + GemmWarpPolicy.Square + """ + if n_warp == 1 and m_warp > 1: + return cls.FullRow + elif m_warp == 1 and n_warp > 1: + return cls.FullCol + else: + return cls.Square diff --git a/tilelang/original/tilelang/tileop/gemm/__init__.py b/tilelang/original/tilelang/tileop/gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d618f1a44de4aacadbdaadd232b7710e8fad9a95 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/__init__.py @@ -0,0 +1,199 @@ +from enum import IntEnum +from tilelang import tvm as tvm +from tvm import tir +from tvm.target import Target +from tvm.ir.base import Node +from tvm.ir import Range +from tvm.runtime import Scriptable +import tvm_ffi +from .gemm_mma import GemmMMA +from .gemm_mma_sm70 import GemmMMASm70 +from .gemm_wgmma import GemmWGMMA +from .gemm_tcgen05 import GemmTCGEN5 +from .gemm_mfma import GemmMFMA +from .gemm_mmac import GemmMMAC +from tilelang import _ffi_api +from tilelang.utils.target import target_is_volta + + +@tvm_ffi.register_global_func("tl.gemm_py.infer_layout") +def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range): + thread_nums = thread_bounds.extent + return gemm_py.infer_layout(target, thread_nums) + + +@tvm_ffi.register_global_func("tl.gemm_py.lower") +def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, thread_var: tir.Var): + thread_nums = thread_bounds.extent + stmt = 
gemm_py.lower(layout_map, target, thread_nums, thread_var) + return stmt + + +# TODO(lei): support Volta and WMMA? +# same definition with src/op/gemm_py.h +class GemmInst(IntEnum): + MMA = 0 + WGMMA = 1 + TCGEN5MMA = 2 + MFMA = 3 + MMAC = 4 + + def is_mma(self) -> bool: + return self == GemmInst.MMA + + def is_wgmma(self) -> bool: + return self == GemmInst.WGMMA + + def is_tcgen5mma(self) -> bool: + return self == GemmInst.TCGEN5MMA + + def is_mfma(self) -> bool: + return self == GemmInst.MFMA + + def is_mmac(self) -> bool: + return self == GemmInst.MMAC + + def __repr__(self) -> str: + return self.name + + +@tvm_ffi.register_object("tl.GemmPy") +class GemmPy(Node, Scriptable): + # FFI fields (LLVM/MLIR-style lowerCamel via reflection): + # a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB, + # strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy + # + # Backward-compat alias properties are provided below to support old names. + + # Backward-compat alias properties (old API → new FFI fields) + @property + def A(self): + return self.a + + @property + def B(self): + return self.b + + @property + def C(self): + return self.c + + @property + def APtr(self): + return self.aPtr + + @property + def BPtr(self): + return self.bPtr + + @property + def CPtr(self): + return self.cPtr + + @property + def M(self): + return self.m + + @property + def N(self): + return self.n + + @property + def K(self): + return self.k + + @property + def trans_A(self): + return self.transA + + @property + def trans_B(self): + return self.transB + + @property + def stride_A(self): + return self.strideA + + @property + def stride_B(self): + return self.strideB + + @property + def offset_A(self): + return self.offsetA + + @property + def offset_B(self): + return self.offsetB + + @property + def clear_accum(self): + return self.clearAccum + + @property + def k_pack(self): + return self.kPack + + @property + def wg_wait(self): + return self.wgWait + + def infer_layout(self, target: Target, thread_nums: int): + """Infer the layout for the GEMM operation based on target architecture.""" + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst, target) + return impl_class(self).infer_layout(target, thread_nums) + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + """Lower the GEMM operation to TIR statements based on target architecture.""" + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst, target) + return impl_class(self).lower(layout_map, target, thread_nums, thread_var) + + def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst: + """Select the appropriate GEMM instruction based on target and thread configuration. + + The selection logic follows this priority: + 1. WGMMA for Hopper architecture with sufficient matrix size and warp count + 2. MFMA for CDNA (AMD) architecture + 3. MMA for CUDA architecture + 4. Fallback to MMA for other cases + + Args: + thread_nums: Number of threads in the block + target: Target architecture + + Returns: + GemmInst: The selected GEMM instruction type + """ + return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target)) + + def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): + """Get the appropriate implementation class for the given GEMM instruction. 
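# Illustrative check (not from the original file): the enum above crosses the
# FFI boundary as a plain integer (see _ffi_api.GemmPyGemmInst), so it
# round-trips cleanly through int.
inst = GemmInst(1)
assert inst.is_wgmma() and int(inst) == 1 and repr(inst) == "WGMMA"
# Callers written against the old field names (gemm.M, gemm.trans_A, ...) keep
# working through the alias properties below, which forward to the lowerCamel
# FFI fields (m, transA, ...).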
+ + Args: + gemm_inst: The selected GEMM instruction type + + Returns: + The implementation class for the instruction type + + Raises: + NotImplementedError: If the instruction type is not supported + ValueError: If the instruction type is unknown + """ + if gemm_inst.is_mma(): + if target_is_volta(target): + return GemmMMASm70 + return GemmMMA + elif gemm_inst.is_wgmma(): + return GemmWGMMA + elif gemm_inst.is_tcgen5mma(): + return GemmTCGEN5 + elif gemm_inst.is_mmac(): + return GemmMMAC + elif gemm_inst.is_mfma(): + return GemmMFMA + elif gemm_inst.is_tcgen5mma(): + raise NotImplementedError("TCGEN5MMA is not implemented") + else: + raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_base.py b/tilelang/original/tilelang/tileop/gemm/gemm_base.py new file mode 100644 index 0000000000000000000000000000000000000000..7d31ae46d76884bfde14ebb9d0e715980682ecb0 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_base.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.utils.language import is_shared, is_fragment +from tilelang.tileop.base import GemmWarpPolicy +from tvm.ir.base import Node +from tvm.ir import PrimExpr + + +@dataclass +class GemmBase: + gemm_node: Node + + def infer_layout(self, target: Target, thread_nums: int): + raise NotImplementedError("infer_layout is not implemented") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + raise NotImplementedError("lower is not implemented") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) + + @property + def M(self) -> int: + return getattr(self.gemm_node, "m", None) + + @property + def N(self) -> int: + return getattr(self.gemm_node, "n", None) + + @property + def K(self) -> int: + return getattr(self.gemm_node, "k", None) + + @property + def trans_A(self) -> bool: + return getattr(self.gemm_node, "transA", None) + + @property + def trans_B(self) -> bool: + return getattr(self.gemm_node, "transB", None) + + @property + def in_dtype(self) -> str: + assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" + return self.A.dtype + + @property + def accum_dtype(self) -> str: + return self.C.dtype + + @property + def chunk(self) -> int: + return self.A.shape[-2] if self.trans_A else self.A.shape[-1] + + @property + def A(self) -> tir.Buffer: + return getattr(self.gemm_node, "a", None) + + @property + def B(self) -> tir.Buffer: + return getattr(self.gemm_node, "b", None) + + @property + def C(self) -> tir.Buffer: + return getattr(self.gemm_node, "c", None) + + @property + def ARegion(self): + return getattr(self.gemm_node, "aRegion", None) + + @property + def BRegion(self): + return getattr(self.gemm_node, "bRegion", None) + + @property + def CRegion(self): + return getattr(self.gemm_node, "cRegion", None) + + @property + def stride_A(self) -> int: + return getattr(self.gemm_node, "strideA", None) + + @property + def stride_B(self) -> int: + return getattr(self.gemm_node, "strideB", None) + + @property + def offset_A(self) -> int: + return getattr(self.gemm_node, "offsetA", None) + + @property + def 
offset_B(self) -> int: + return getattr(self.gemm_node, "offsetB", None) + + @property + def clear_accum(self) -> PrimExpr: + return getattr(self.gemm_node, "clearAccum", None) + + @property + def k_pack(self) -> int: + return getattr(self.gemm_node, "kPack", None) + + @property + def wg_wait(self) -> int: + return getattr(self.gemm_node, "wgWait", 0) + + @property + def policy(self) -> GemmWarpPolicy: + return getattr(self.gemm_node, "policy", None) + + @property + def mbarptr(self) -> PrimExpr: + return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, T.uint32)) + + @property + def mbar(self) -> tir.Buffer: + return getattr(self.gemm_node, "mbar", None) + + @property + def C_coords(self): + coords = getattr(self.gemm_node, "cCoords", None) + if coords is None or len(coords) == 0: + zero = tvm.tir.const(0, T.int32) + return [zero, zero] + return [coords[i] for i in range(len(coords))] + + def get_region_base_offsets(self, region): + """ + Get the base offset (start index) for each dimension from a BufferRegion. + + For example, if region is A_shared[ko % 2, 0:128, 0:64], + this returns [ko % 2, 0, 0] + + Args: + region: BufferRegion object + + Returns: + List of PrimExpr representing the base offset for each dimension + """ + if region is None: + return [] + return [r.min for r in region.region] + + @property + def A_base_offsets(self): + """Get base offsets for each dimension of A region""" + return self.get_region_base_offsets(self.ARegion) + + @property + def B_base_offsets(self): + """Get base offsets for each dimension of B region""" + return self.get_region_base_offsets(self.BRegion) + + @property + def C_base_offsets(self): + """Get base offsets for each dimension of C region""" + return self.get_region_base_offsets(self.CRegion) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/original/tilelang/tileop/gemm/gemm_mfma.py new file mode 100644 index 0000000000000000000000000000000000000000..d827d8a2a3fa9ad901dadf214e576f942619f0af --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_mfma.py @@ -0,0 +1,227 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mfma_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.utils.language import is_shared, is_fragment, is_full_region +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMFMA(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mfma_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + k_pack=self.k_pack, + ) + + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: 
mfma_emitter.make_mfma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"), + self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mfma_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + k_pack=self.k_pack, + ) + + in_dtype = self.in_dtype + warp_rows = mfma_emitter.warp_rows + warp_cols = mfma_emitter.warp_cols + local_size_a = mfma_emitter.local_size_a + local_size_b = mfma_emitter.local_size_b + block_K = mfma_emitter.chunk + micro_size_k = mfma_emitter.micro_size_k + # Use region for shared-memory operands if available + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + A_buf = A_region.buffer + B_buf = B_region.buffer + C_buf = C_region.buffer + + clear_accum = self.clear_accum + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + + assert is_full_region(C_region), "Fragment output C must be a full region" + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load A into fragment + mfma_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Load B into fragment + mfma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. 
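# Back-of-the-envelope sketch (illustrative only) for the tiling set up in
# lower() above. The numbers are assumptions, not fixed by this file: a
# 128x128 output block, block_K = 32, 4 warps under the Square policy, an
# assumed MFMA micro-k of 16, and no k-packing.
M, N, block_K = 128, 128, 32
m_warp, n_warp = 2, 2                      # Square policy on 4 warps
micro_size_k, k_pack = 16, 1               # assumed MFMA atom / no packing
warp_row_tiles, warp_col_tiles = M // m_warp, N // n_warp
k_iters = block_K // (micro_size_k * k_pack)
# each warp owns a 64x64 slice of C and the ki loop runs twice per block_K
assert (warp_row_tiles, warp_col_tiles, k_iters) == (64, 64, 2)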
+ """ + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) + + if clear_accum: + T.clear(C_buf) + + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load A into fragment + mfma_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + assert is_full_region(A_region), "Fragment input A must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load B into fragment + mfma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_buf, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Perform Matrix Multiplication + mfma_emitter.mfma(A_buf, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_mma.py b/tilelang/original/tilelang/tileop/gemm/gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..b15173483813aa28f0a72d74260fba4b23dab3e7 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_mma.py @@ -0,0 +1,222 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.utils.language import is_shared, is_fragment, is_full_region +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMMA(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = 
TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + block_K = mma_emitter.chunk + micro_size_k = mma_emitter.micro_size_k + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + A_buf = A_region.buffer + B_buf = B_region.buffer + C_buf = C_region.buffer + + clear_accum = self.clear_accum + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + + assert is_full_region(C_region), "Fragment output C must be a full region" + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
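# Hedged, user-level sketch of the strided/region case the comment above refers
# to. Shapes, stage count, and the double-buffering idiom are assumptions for
# illustration, not taken from this file; the point is only that T.gemm may
# receive a sliced region of a multi-stage shared buffer and the lowering
# resolves its base offsets.
import tilelang
import tilelang.language as T

@tilelang.jit
def matmul(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=64):

    @T.prim_func
    def main(A: T.Tensor((M, K), "float16"),
             B: T.Tensor((K, N), "float16"),
             C: T.Tensor((M, N), "float32")):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((2, block_M, block_K), "float16")  # two stages
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared[ko % 2, :, :])
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # only stage ko % 2 of A_shared participates in this iteration
                T.gemm(A_shared[ko % 2, :, :], B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main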
+ """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + if clear_accum: + T.clear(C_buf) + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + assert is_full_region(A_region), "Fragment input A must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // micro_size_k)): + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_buf, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_rrr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
+ """ + + for ki in T.serial(0, (block_K // micro_size_k)): + # Perform Matrix Multiplication + mma_emitter.mma(A_buf, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rrr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_mma_sm70.py b/tilelang/original/tilelang/tileop/gemm/gemm_mma_sm70.py new file mode 100644 index 0000000000000000000000000000000000000000..52a4bf3262f0054be28158a0d3c0db7863512ddf --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_mma_sm70.py @@ -0,0 +1,166 @@ +# for Volta GPUs, which use legacy MMA instructions +from .gemm_base import GemmBase +from tilelang.layout import make_volta_swizzled_layout +from tilelang.intrinsics.mma_sm70_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.utils.language import is_shared, is_fragment, is_full_region +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMMASm70(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + if self.is_gemm_ss(): + return { + self.A: make_volta_swizzled_layout(self.A, is_a=True, k_inner=a_is_k_major), + self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = 
mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + block_K = mma_emitter.chunk + micro_size_k = mma_emitter.micro_size_k + # Use region for shared-memory operands when applicable + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + A_buf = A_region.buffer + C_buf = C_region.buffer + + clear_accum = self.clear_accum + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + + assert is_full_region(C_region), "Fragment output C must be a full region" + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + if clear_accum: + T.clear(C_buf) + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_rs(): + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + if clear_accum: + T.clear(C_buf) + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_buf, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_mmac.py b/tilelang/original/tilelang/tileop/gemm/gemm_mmac.py new file mode 100644 index 0000000000000000000000000000000000000000..4b560f64a7474724ad32b042b91dd0da382e1516 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_mmac.py @@ -0,0 +1,227 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mmac_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.utils.language import is_shared, is_fragment, is_full_region +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMMAC(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, 
target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mmac_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + k_pack=self.k_pack, + ) + + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mmac_emitter.make_mmac_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mmac_emitter.make_mmac_load_layout(self.B, matrix="B"), + self.C: mmac_emitter.make_mmac_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mmac_emitter.make_mmac_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mmac_emitter.make_mmac_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mmac_emitter.make_mmac_load_layout(self.A, matrix="A"), + self.B: mmac_emitter.make_mmac_load_layout(self.B, matrix="B"), + self.C: mmac_emitter.make_mmac_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mmac_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + k_pack=self.k_pack, + ) + + in_dtype = self.in_dtype + warp_rows = mmac_emitter.warp_rows + warp_cols = mmac_emitter.warp_cols + local_size_a = mmac_emitter.local_size_a + local_size_b = mmac_emitter.local_size_b + block_K = mmac_emitter.chunk + micro_size_k = mmac_emitter.micro_size_k + # Use region for shared-memory operands if available + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + A_buf = A_region.buffer + B_buf = B_region.buffer + C_buf = C_region.buffer + + clear_accum = self.clear_accum + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + + assert is_full_region(C_region), "Fragment output C must be a full region" + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mmac ops, + accumulating into C_local. 
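# Note (illustrative, not from the original file): across all of these
# lowerings, clear_accum only controls whether the accumulator fragment is
# zeroed before the k-loop. In user code that is, presumably, the difference
# between
#   T.gemm(A_shared, B_shared, C_local)                    # C_local += A @ B
#   T.gemm(A_shared, B_shared, C_local, clear_accum=True)  # C_local  = A @ B
# with the first form being the default accumulate-in-place behaviour.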
+ """ + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load A into fragment + mmac_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Load B into fragment + mmac_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_local, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mmac ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) + + if clear_accum: + T.clear(C_buf) + + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load A into fragment + mmac_emitter.ldmatrix_a( + A_local, + A_region, + ki, + ) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_local, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + assert is_full_region(A_region), "Fragment input A must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mmac ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Load B into fragment + mmac_emitter.ldmatrix_b( + B_local, + B_region, + ki, + ) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_buf, B_local, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(B_region), "Fragment input B must be a full region" + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mmac ops, + accumulating into C_local. 
+ """ + + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + # Perform Matrix Multiplication + mmac_emitter.mmac(A_buf, B_buf, C_buf, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/original/tilelang/tileop/gemm/gemm_tcgen05.py new file mode 100644 index 0000000000000000000000000000000000000000..de3e72143c098a10dafaa2b2df1b440c8b56bba2 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_tcgen05.py @@ -0,0 +1,114 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_tcgen05mma_swizzled_layout +from tilelang.intrinsics.tcgen05_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang import language as T +from tilelang.transform.simplify import _Simplify +from tvm import tir +from tvm.target import Target + +_FLOAT8_DTYPES = { + "float8_e4m3", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fn", + "float8_e5m2fnuz", +} + + +class GemmTCGEN5(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + + if self.is_gemm_ss(): + a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp + b_continuity = self.K if b_is_k_major else self.N // n_warp + + return { + # WGMMA does not support padding + self.A: make_tcgen05mma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_tcgen05mma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + # No special swizzle requirement; rely on existing layout. 
+ return {} + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + + if self.A in layout_map: + mma_emitter._assign_a_shared_layout(layout_map[self.A]) + if self.B in layout_map: + mma_emitter._assign_b_shared_layout(layout_map[self.B]) + + if not self.is_gemm_ss(): + raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got A scope {self.A.scope()}, B scope {self.B.scope()}") + + atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K) + + if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: + raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") + if self.B.scope() not in {"shared", "shared.dyn"}: + raise ValueError(f"Unsupported B scope for TCGEN5MMA: {self.B.scope()}") + if self.C.scope() != "shared.tmem": + raise ValueError(f"TCGEN5MMA expects C in shared.tmem, got {self.C.scope()}") + if self.wg_wait != -1: + raise ValueError("TCGEN5MMA currently requires wg_wait == -1") + + mbar = self.mbar + if mbar == 0: + raise ValueError("TCGEN5MMA requires a valid mbarrier") + + mbarptr = mbar.access_ptr("rw") + + C_coords = self.C_coords + if len(C_coords) != 2: + raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") + + accum_dtype = str(self.C.dtype) + if accum_dtype not in [str(T.float32), str(T.float16)]: + raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") + + A_shared = self.ARegion + B_shared = self.BRegion + C_local = self.C + clear_accum = self.clear_accum + + @T.prim_func + def _gemm_ss() -> None: + if thread_var // 32 == 0: + mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbarptr, clear_accum) + + return _Simplify(_gemm_ss, inline_let=True) diff --git a/tilelang/original/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/original/tilelang/tileop/gemm/gemm_wgmma.py new file mode 100644 index 0000000000000000000000000000000000000000..038aa2cd66692bb50386ccec669083c615882420 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm/gemm_wgmma.py @@ -0,0 +1,136 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_wgmma_swizzled_layout +from tilelang.intrinsics.wgmma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmWGMMA(GemmBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + 
warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + + if self.is_gemm_ss(): + a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp + b_continuity = self.K if b_is_k_major else self.N // n_warp + + return { + # WGMMA does not support padding + self.A: make_wgmma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) + + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + if self.A in layout_map: + mma_emitter._assign_a_shared_layout(layout_map[self.A]) + if self.B in layout_map: + mma_emitter._assign_b_shared_layout(layout_map[self.B]) + + # Get base offsets from regions + # All dimensions may have offsets, including the matrix dimensions + # However, for WGMMA, we pass the Buffer directly and handle offsets + # through proper indexing in the access_ptr call or buffer slicing + + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + clear_accum = self.clear_accum + wg_wait = self.wg_wait + + if self.is_gemm_ss(): + # For WGMMA, we need to handle buffer region offsets + # If there are offsets, we create a BufferLoad inside the prim_func + # to properly generate offset access + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + # Perform Matrix Multiplication with offset consideration + mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_rs(): + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
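+
+ For the rs case, A already resides in a register fragment; only the B
+ operand is read from shared memory by the wgmma call below.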
+ """ + mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tileop/gemm_sp/__init__.py b/tilelang/original/tilelang/tileop/gemm_sp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d75657ec6c49333d184581880e861fed468c114 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm_sp/__init__.py @@ -0,0 +1,69 @@ +from tilelang import tvm as tvm +from tvm import tir +from tilelang.utils.target import ( + target_is_cuda, +) +from tvm.target import Target +from tvm.ir.base import Node +from tvm.ir import Range +from tvm.runtime import Scriptable +import tvm_ffi +from tilelang.tileop.base import GemmWarpPolicy +from .gemm_sp_mma import GemmSPMMA + + +@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout") +def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range): + thread_nums = thread_bounds.extent + return gemm_sp_py.infer_layout(target, thread_nums) + + +@tvm_ffi.register_global_func("tl.gemm_sp_py.lower") +def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var): + thread_nums = thread_bounds.extent + stmt = gemm_sp_py.lower(target, thread_nums, thread_var) + return stmt + + +@tvm_ffi.register_object("tl.GemmSPPy") +class GemmSPPy(Node, Scriptable): + A: tir.Buffer + E: tir.Buffer + B: tir.Buffer + C: tir.Buffer + + APtr: tir.PrimExpr + EPtr: tir.PrimExpr + BPtr: tir.PrimExpr + CPtr: tir.PrimExpr + + M: int + N: int + K: int + + trans_A: bool + trans_B: bool + + stride_A: int + stride_B: int + offset_A: int + offset_B: int + clear_accum: bool + k_pack: int + wg_wait: int + policy: GemmWarpPolicy + + def infer_layout(self, target: Target, thread_nums: int): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + return GemmSPMMA(self).infer_layout(target, thread_nums) + else: + raise ValueError(f"Unsupported target: {target}") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + if target_is_cuda(target): + # TODO(lei): Support more cuda architectures, now mma only + # Now only implement ssr layout + return GemmSPMMA(self).lower(target, thread_nums, thread_var) + else: + raise ValueError(f"Unsupported target: {target}") diff --git a/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_base.py b/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_base.py new file mode 100644 index 0000000000000000000000000000000000000000..8226a066417d5f78c1d4c9b70e35a58a7d382568 --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_base.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang.utils.language import is_shared, is_fragment +from tilelang.tileop.base import GemmWarpPolicy +from tvm.ir.base import Node + + +@dataclass +class GemmSPBase: + gemm_sp_node: Node + + def 
infer_layout(self, target: Target, thread_nums: int): + raise NotImplementedError("infer_layout is not implemented") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + raise NotImplementedError("lower is not implemented") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) + + @property + def M(self) -> int: + return self.gemm_sp_node.M + + @property + def N(self) -> int: + return self.gemm_sp_node.N + + @property + def K(self) -> int: + return self.gemm_sp_node.K + + @property + def trans_A(self) -> bool: + return self.gemm_sp_node.trans_A + + @property + def trans_B(self) -> bool: + return self.gemm_sp_node.trans_B + + @property + def trans_E(self) -> bool: + return self.gemm_sp_node.trans_E + + @property + def e_dtype(self) -> str: + return self.E.dtype + + @property + def in_dtype(self) -> str: + assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" + return self.A.dtype + + @property + def accum_dtype(self) -> str: + return self.C.dtype + + @property + def A(self) -> tir.Buffer: + return self.gemm_sp_node.A + + @property + def E(self) -> tir.Buffer: + return self.gemm_sp_node.E + + @property + def B(self) -> tir.Buffer: + return self.gemm_sp_node.B + + @property + def C(self) -> tir.Buffer: + return self.gemm_sp_node.C + + @property + def ARegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.ARegion + + @property + def ERegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.ERegion + + @property + def BRegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.BRegion + + @property + def CRegion(self) -> tir.PrimExpr: + return self.gemm_sp_node.CRegion + + @property + def stride_A(self) -> int: + return self.gemm_sp_node.stride_A + + @property + def stride_B(self) -> int: + return self.gemm_sp_node.stride_B + + @property + def offset_A(self) -> int: + return self.gemm_sp_node.offset_A + + @property + def offset_B(self) -> int: + return self.gemm_sp_node.offset_B + + @property + def clear_accum(self) -> bool: + return self.gemm_sp_node.clear_accum + + @property + def k_pack(self) -> int: + return self.gemm_sp_node.k_pack + + @property + def wg_wait(self) -> int: + return self.gemm_sp_node.wg_wait + + @property + def policy(self) -> GemmWarpPolicy: + return self.gemm_sp_node.policy diff --git a/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..76a0d4a9ed8a3800e0a2017e7ee3a9b7995af49e --- /dev/null +++ b/tilelang/original/tilelang/tileop/gemm_sp/gemm_sp_mma.py @@ -0,0 +1,241 @@ +from .gemm_sp_base import GemmSPBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmSPMMA(GemmSPBase): + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = 
int(self.N // n_warp) + mma_emitter = SparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + ) + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = SparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + e_dtype=self.e_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols + local_size_a = mma_emitter.local_size_a + local_size_e = mma_emitter.local_size_e + local_size_b = mma_emitter.local_size_b + micro_size_k = mma_emitter.micro_size_k + A_shared = self.A + E_shared = self.E + B_shared = self.B + C_local = self.C + assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}" + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
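+
+ The sparsity-metadata buffer E_shared is also staged into a local
+ fragment (E_local) and passed to mma_sp together with A and B.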
+ """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + B_local = self.B + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + A_local = self.A + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + A_local = self.A + B_local = self.B + + @T.prim_func + def _gemm_rrr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. 
+ """ + E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) + + for ki in T.serial(0, (self.K // micro_size_k)): + # Load E into fragment + mma_emitter.ldmatrix_e( + E_local, + E_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rrr, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/original/tilelang/tools/Analyzer.py b/tilelang/original/tilelang/tools/Analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..3af5222f29fb24cb2907d43da659a3efa28e7223 --- /dev/null +++ b/tilelang/original/tilelang/tools/Analyzer.py @@ -0,0 +1,218 @@ +from __future__ import annotations +import numpy as np +from dataclasses import dataclass +from tilelang import tvm +from tvm.tir.stmt_functor import ir_transform +import logging + +# Configuration for different hardware architectures. +# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count) +ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)} + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class AnalysisResult: + """ + A data class to store the results of the analysis. + Attributes: + total_flops: Total floating-point operations. + total_global_bytes: Total bytes transferred to/from global memory. + estimated_time: Estimated execution time (seconds). + tflops: Achieved TFLOPS (trillions of FLOPs per second). + bandwidth_GBps: Achieved memory bandwidth in GB/s. + """ + + total_flops: int + total_global_bytes: int + estimated_time: float + expected_tflops: float + expected_bandwidth_GBps: float + + +class Analyzer: + """ + A class to analyze the performance of a TVM IR module. + It calculates metrics such as FLOPs, memory bandwidth, and estimated execution time. + """ + + def __init__(self, fn, device): + """ + Initialize the Analyzer. + Args: + fn: A TVM IRModule or PrimFunc to analyze. + device: The target device information. + """ + if isinstance(fn, tvm.tir.function.PrimFunc): + self.fn = tvm.IRModule({"main": fn}) + else: + self.fn = fn + self.device = device + self.total_flops = 0 # Total floating-point operations + self.total_global_bytes = 0 # Total global memory bytes + self.block_counts = {"blockIdx.x": 1, "blockIdx.y": 1} # Block dimensions + self.loop_stack = [] # Stack to track nested loops + self.global_buffers = set() # Set of global memory buffers + + def _analyze_copy(self, call): + """ + Analyze memory copy operations (e.g., tl.copy). + Args: + call: A TVM Call node representing the copy operation. 
+ """ + src_buffer = call.args[0].args[0].buffer + dst_buffer = call.args[1].args[0].buffer + + # Determine if the source or destination is a global buffer + if src_buffer in self.global_buffers: + buffer_region = call.args[0] + elif dst_buffer in self.global_buffers: + buffer_region = call.args[1] + else: + return + + # Calculate the number of elements being copied + elements = 1 + for r in range(2, len(buffer_region.args)): + elements *= buffer_region.args[r] + dtype_size = np.dtype(buffer_region.args[0].buffer.dtype).itemsize # Size of the data type + bytes_transferred = elements * dtype_size # Total bytes transferred + + # Account for loop and block dimensions + loop_product = 1 + for extent in self.loop_stack: + loop_product *= extent.value if hasattr(extent, "value") else extent + total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] + total_bytes = bytes_transferred * loop_product * total_blocks + self.total_global_bytes += total_bytes + + def _analyze_gemm(self, call): + """ + Analyze matrix multiplication (GEMM) operations (e.g., tl.gemm). + Args: + call: A TVM Call node representing the GEMM operation. + """ + M = call.args[5].value + N = call.args[6].value + K = call.args[7].value + flops_per_call = 2 * M * N * K # FLOPs for one GEMM operation + + # Account for loop and block dimensions + loop_product = 1 + for extent in self.loop_stack: + loop_product *= extent.value if hasattr(extent, "value") else extent + total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] + self.total_flops += flops_per_call * loop_product * total_blocks + + def ir_pass(self): + """ + Traverse and transform the IR module to extract performance-related information. + Returns: + self: The Analyzer instance. + """ + + def _ftransform(f, mod, ctx): + # Initialize the set of global buffers + self.global_buffers = set(f.buffer_map.values()) + + def _pre_visit(stmt): + """ + Pre-visit callback for IR nodes. + Args: + stmt: The current IR node being visited. + """ + if isinstance(stmt, tvm.tir.AttrStmt): + # Handle thread extent attributes + if stmt.attr_key == "thread_extent": + iter_var = stmt.node + thread_tag = iter_var.thread_tag + if thread_tag in self.block_counts: + extent = stmt.value.value if hasattr(stmt.value, "value") else stmt.value + self.block_counts[thread_tag] = extent + elif isinstance(stmt, tvm.tir.For): + # Push loop extent onto the stack + self.loop_stack.append(stmt.extent) + elif isinstance(stmt, tvm.tir.Evaluate): + # Handle Evaluate nodes containing calls + value = stmt.value + if isinstance(value, tvm.tir.Call): + if value.op.name == "tl.copy": + self._analyze_copy(value) + elif value.op.name == "tl.gemm": + self._analyze_gemm(value) + return None + + def _post_visit(stmt): + """ + Post-visit callback for IR nodes. + Args: + stmt: The current IR node being visited. + """ + if isinstance(stmt, tvm.tir.For) and self.loop_stack: + self.loop_stack.pop() + return None + + # Use IR transformation to traverse and modify the function body + new_body = ir_transform(f.body, _pre_visit, _post_visit) + return f.with_body(new_body) + + # Apply the custom PrimFunc pass + tvm.tir.transform.prim_func_pass(_ftransform, opt_level=0)(self.fn) + return self + + def calculate(self) -> AnalysisResult: + """ + Calculate performance metrics based on the analysis. + Returns: + AnalysisResult: The calculated performance metrics. + """ + + def get_peak_tflops(device) -> float | None: + """ + Get the peak TFLOPS for the target device. 
+ Args: + device: The target device information. + Returns: + float: The peak TFLOPS. + """ + arch_key = device.compute_capability[:2] + if arch_key not in ARCH_CONFIGS: + logger.info(f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None") + return None + + cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key] + total_cores = compute_max_core * cores_per_sm + tflops = (total_cores * default_clock * flops_per_cycle) / 1e3 + return round(tflops, 1) + + # Calculate memory bandwidth and peak TFLOPS + bandwidth_GBps = self.device.bandwidth[1] / 1000 + peak_tflops = get_peak_tflops(self.device) + + # Estimate memory and compute times + mem_time = self.total_global_bytes / (bandwidth_GBps * 1e9) + compute_time = self.total_flops / (peak_tflops * 1e12) if peak_tflops else None + estimated_time = max(mem_time, compute_time) if peak_tflops else mem_time + + # Return the analysis results + return AnalysisResult( + total_flops=self.total_flops, + total_global_bytes=self.total_global_bytes, + estimated_time=estimated_time, + expected_tflops=peak_tflops, + expected_bandwidth_GBps=bandwidth_GBps, + ) + + @classmethod + def analysis(cls, fn, device): + """ + Perform a full analysis of the given IR module or PrimFunc. + Args: + fn: A TVM IRModule or PrimFunc to analyze. + device: The target device information. + Returns: + AnalysisResult: The calculated performance metrics. + """ + return cls(fn, device).ir_pass().calculate() diff --git a/tilelang/original/tilelang/tools/__init__.py b/tilelang/original/tilelang/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8bde51408bea4c73f19e2a45ab2fb4e33f345b --- /dev/null +++ b/tilelang/original/tilelang/tools/__init__.py @@ -0,0 +1,2 @@ +from .plot_layout import plot_layout # noqa: F401 +from .Analyzer import * diff --git a/tilelang/original/tilelang/tools/plot_layout.py b/tilelang/original/tilelang/tools/plot_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..299c3e86b6325c84e43c575c6dc13729725b736b --- /dev/null +++ b/tilelang/original/tilelang/tools/plot_layout.py @@ -0,0 +1,214 @@ +from __future__ import annotations +import tilelang.language as T + + +def plot_layout( + layout: T.Fragment, + save_directory="./tmp", + name: str = "layout", + colormap: str = "RdPu", + verbose: bool = False, + formats: str | list[str] = "png", +) -> None: + """ + Plot the layout of a buffer. + + Parameters + ---------- + layout : T.Layout + The layout object that describes how indices are mapped. + save_directory : str, optional + The directory where the output images will be saved (default is "./tmp"). + name : str, optional + The base name of the output files (default is "layout"). + colormap : str, optional + The colormap to use for visualization (default is "RdPu"). + verbose : bool, optional + If True, prints additional information about the mapping (default is False). + formats : str | list[str], optional + The formats to save the image in (default is "png"). 
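+ Pass "all" to save every supported format (pdf, png and svg), or a
+ comma-separated string such as "png,pdf" to select several at once.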
+ Returns + ------- + None + """ + import os + import pathlib + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.patches as patches + + # Get the input shape of the layout and convert it to a list of integers + input_shape = layout.get_input_shape() + input_shape = [int(var) for var in input_shape] + replicate_size = int(layout.replicate_size) + + # Get the total number of threads + num_threads = int(layout.get_thread_size()) + + import itertools + + # Initialize a 2D array to store thread mappings + thread_map = np.empty(input_shape, dtype=object) + for idx in np.ndindex(thread_map.shape): + thread_map[idx] = [] + + # Initialize a 2D array to store value mappings + value_map = np.zeros(input_shape, dtype=object) + for idx in np.ndindex(value_map.shape): + value_map[idx] = [] + + # Iterate over all possible indices in the input shape + for i in range(replicate_size): + for idx in itertools.product(*[range(dim) for dim in input_shape]): + index = list(idx) + # If replication is enabled, adjust the index + if replicate_size > 1: + index.insert(0, i) + # Map the index to a thread ID + thread_id = layout.map_forward_thread(index) + assert len(thread_id) == 1 # Ensure a single-thread mapping + thread_map[idx].append(int(thread_id[0])) # Store the thread ID + + # Iterate again to map values + for i in range(replicate_size): + for idx in itertools.product(*[range(dim) for dim in input_shape]): + index = list(idx) + if replicate_size > 1: + index.insert(0, i) + thread_id = layout.map_forward_thread(index) + value_id = layout.map_forward_index(index) + assert len(value_id) == 1 # Ensure a single-value mapping + value_map[idx].append(int(value_id[0])) # Store the value ID + + # Load the colormap with twice as many colors as the number of threads + cmap = plt.get_cmap(colormap, num_threads * 2 // replicate_size) + + # Generate a list of colors based on the colormap + raw_colors = [cmap(i) for i in range(num_threads)] + colors = raw_colors.copy() + + # Show the distribution of registers in each thread of a warp. + warp_size = 32 + # Warn if the number of threads is less than the warp size + if num_threads < warp_size: + import warnings + + warnings.warn( + f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). 
" + f"For the best viewing experience, it is recommended to have at least {warp_size} threads.", + UserWarning, + stacklevel=2, + ) + spectral_camp = plt.get_cmap("hsv", warp_size * 6) + + for i in range(min(warp_size, num_threads)): + colors[i] = spectral_camp(i * 6) + + # Determine the number of rows and columns in the input shape + nrows, ncols = input_shape + # Adjust figure size to maintain square cells + cell_size = 1 # Base size for each cell + plt.figure(figsize=(cell_size * ncols, cell_size * nrows)) # Set the figure size proportionally + ax = plt.gca() # Get the current axis + font_size = 24 # Set font size for text annotation + + # Iterate through each row and column + for i in range(nrows): + for j in range(ncols): + thread_ids = thread_map[i, j] # Get the thread ID + local_ids = value_map[i, j] # Get the value ID + if verbose: + print(f"thread_map[{i}, {j}] = {thread_ids} value_map[{i}, {j}] = {local_ids}") + + color = colors[thread_ids[0]] # Select color based on thread ID + # Create a rectangle patch for visualization + rect = patches.Rectangle((j, i), 1, 1, linewidth=0.5, edgecolor="black", facecolor=color) + ax.add_patch(rect) # Add the rectangle to the plot + + # Add text annotations inside the rectangles + thread_str = [] + for thread_id in thread_ids: + thread_str.append(f"{thread_id}") + thread_str = "T" + "/".join(thread_str) + local_id = local_ids[0] + # assert local id in local_ids is equal + assert all(local_id == local_id for local_id in local_ids) + + # Calculate thread font size based on string length + thread_fontsize = min(font_size, font_size * (4 / len(thread_str))) + + # Add thread ID text with adjusted font size + ax.text(j + 0.5, i + 0.3, thread_str, ha="center", va="center", color="black", fontsize=thread_fontsize) + # Add local ID text with original font size + ax.text(j + 0.5, i + 0.7, f"L{local_id}", ha="center", va="center", color="black", fontsize=font_size) + + # Add row labels to the left side of the plot + for i in range(nrows): + text = f"row {i}" + ax.text(-0.75, i + 0.5, text, ha="center", va="center", color="black", fontsize=font_size) + + # Add column labels at the top of the plot + for j in range(ncols): + text = f"col {j}" + ax.text(j + 0.5, -0.5, text, ha="center", va="center", color="black", fontsize=font_size, rotation=45) + + # Set the plot limits + ax.set_xlim(0, ncols) + ax.set_ylim(0, nrows) + ax.invert_yaxis() # Invert the y-axis for proper visualization + plt.xticks([]) # Remove x-axis ticks + plt.yticks([]) # Remove y-axis ticks + + # Calculate legend position based on figure size + fig = plt.gcf() + fig_width = fig.get_size_inches()[0] + fig_height = fig.get_size_inches()[1] + legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width + legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height + + legend_patches = [patches.Patch(color="black", label="T: Thread ID"), patches.Patch(color="black", label="L: Local ID")] + ax.legend( + handles=legend_patches, + loc="upper right", + fontsize=font_size - 4, + frameon=False, + bbox_to_anchor=(legend_x, legend_y), # Dynamic position + ncols=2, + ) + + # Create the output directory if it does not exist + tmp_directory = pathlib.Path(save_directory) + if not os.path.exists(tmp_directory): + os.makedirs(tmp_directory) + + # Save the figure in multiple formats + plt.tight_layout() + + if isinstance(formats, str): + formats_str = formats.strip().lower() + if formats_str == "all": + formats_list = ["pdf", "png", "svg"] + elif "," in formats_str: + 
formats_list = [f.strip() for f in formats_str.split(",")] + else: + formats_list = [formats_str] + else: + raise TypeError( + f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." + ) + + # Save the figure + if "pdf" in formats_list: + pdf_path = tmp_directory / f"{name}.pdf" + plt.savefig(pdf_path, bbox_inches="tight") + print(f"Saved pdf format into {pdf_path}") + + if "png" in formats_list: + png_path = tmp_directory / f"{name}.png" + plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) + print(f"Saved png format into {png_path}") + + if "svg" in formats_list: + svg_path = tmp_directory / f"{name}.svg" + plt.savefig(svg_path, bbox_inches="tight", format="svg") + print(f"Saved svg format into {svg_path}") diff --git a/tilelang/original/tilelang/transform/__init__.py b/tilelang/original/tilelang/transform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..697dee2b19da6aac7099f023802b6d658951c5f9 --- /dev/null +++ b/tilelang/original/tilelang/transform/__init__.py @@ -0,0 +1,495 @@ +"""Wrapping transformations.""" +# pylint: disable=invalid-name, unsupported-binary-operation + +from . import _ffi_api +from .simplify import Simplify, simplify_prim_func, LetInline # noqa: F401 +from .pass_config import PassConfigKey # noqa: F401 +from tilelang import tvm as tvm # noqa: F401 +from tvm.ir.transform import PassContext # noqa: F401 +from .add_bufstore_wrapper import AddWrapperForSingleBufStore # noqa: F401 + + +def get_pass_context(): + """Get the current pass context""" + return PassContext.current() + + +def ClusterPlanning(): + """ClusterPlanning + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ClusterPlanning() # type: ignore + + +def PipelinePlanning(): + """infer the fragment/shared memory layout + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.PipelinePlanning() # type: ignore + + +def LayoutInference(): + """LayoutInference + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LayoutInference() # type: ignore + + +def LowerTileOp(): + """LowerTileOp + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerTileOp() # type: ignore + + +def InjectSoftwarePipeline(): + """InjectSoftwarePipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectSoftwarePipeline() # type: ignore + + +def FrontendLegalize(): + """FrontendLegalize + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FrontendLegalize() # type: ignore + + +def LegalizeNegativeIndex(): + """Legalize negative indices in buffer loads. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeNegativeIndex() # type: ignore + + +def InjectAssumes(): + """Inject Assumes for natural shape boundary conditions. And convert Assumes in Evaluate(Call(...)) form + (tvm builtin assume call) to AttrNode form. 
+ + Returns: + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectAssumes() + + +def LowerHopperIntrin(): + """LowerHopperIntrin + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore + + +def WarpSpecializedPipeline(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.WarpSpecializedPipeline() # type: ignore + + +def RewriteWgmmaSync(): + """RewriteWgmmaSync + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RewriteWgmmaSync() # type: ignore + + +def ThreadSync(storage_scope: str): + """Insert sync between parallel read/write of shared buffers. + + Parameters + ---------- + storage_scope: str + The target storage scope. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ThreadSync(storage_scope) # type: ignore + + +def ThreadPartialSync(storage_scope: str): + """Insert partial sync. + + Parameters + ---------- + storage_scope: str + The target storage scope. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ThreadPartialSync(storage_scope) # type: ignore + + +def IfStmtBinding(): + """IfStmtBinding + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.IfStmtBinding() # type: ignore + + +def MergeIfStmt(): + """MergeIfStmt + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeIfStmt() # type: ignore + + +def MultiVersionBuffer(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MultiVersionBuffer() # type: ignore + + +def WarpSpecialized(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.WarpSpecialized() # type: ignore + + +def AnnotateWarpGroupRegAlloc(): + """Inject set_max_nreg calls into warp-specialized functions. + + This pass analyzes the function to collect register hints from set_max_nreg + and no_set_max_nreg calls, then injects appropriate set_max_nreg calls into + producer and consumer branches of warp-specialized code. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateWarpGroupRegAlloc() # type: ignore + + +def InjectTmaBarrier(): + """InjectTmaBarrier + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectTmaBarrier() # type: ignore + + +def InjectFenceProxy(): + """InjectFenceProxy + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectFenceProxy() # type: ignore + + +def LegalizeVectorizedLoop(): + """LegalizeLoopVectorize + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeVectorizedLoop() # type: ignore + + +def LegalizeSafeMemoryAccess(): + """LegalizeLoopVectorize + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeSafeMemoryAccess() # type: ignore + + +def MakePackedAPI(): + """MakePackedAPI + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MakePackedAPI() # type: ignore + + +def AnnotateDeviceRegions(): + """AnnotateDeviceRegions + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateDeviceRegions() # type: ignore + + +def SplitHostDevice(): + """Split host/device functions even for empty kernels. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.SplitHostDevice() # type: ignore + + +def AnnotateReadOnlyParams(): + """Annotate read-only handle parameters for PrimFuncs. + + Adds attribute `tl.readonly_param_indices` listing param indices that are + never written, enabling CUDA codegen to emit `const` qualifiers to unlock + read-only cache loads. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateReadOnlyParams() # type: ignore + + +def VectorizeLoop(enable_vectorize: bool = True): + """VectorizeLoop + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore + + +def InjectPTXAsyncCopy(): + """Rewrite global to shared memory copy on CUDA with asynchronous copy. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectPTXAsyncCopy() # type: ignore + + +def LowerDeviceStorageAccessInfo(): + """Lower attached storage access information on device. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + Note + ---- + Run this pass after all storage access analysis finish. + """ + return _ffi_api.LowerDeviceStorageAccessInfo() # type: ignore + + +def ConfigIndexBitwidth(): + """Config index bitwidth. 
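+
+ The target bitwidth is typically configured via the ``tl.config_index_bitwidth``
+ pass config (see PassConfigKey.TL_CONFIG_INDEX_BITWIDTH, default 32).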
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + ---- + """ + return _ffi_api.ConfigIndexBitwidth() # type: ignore + + +def FlattenBuffer(): + """FlattenBuffer + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FlattenBuffer() # type: ignore + + +def EliminateStorageSyncForMBarrier(): + """EliminateStorageSyncForMBarrier""" + return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore + + +def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_bytes: int = 16): + """MergeSharedMemoryAllocations + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore + + +def LowerL2Persistent(): + """LowerL2Persistent""" + return _ffi_api.LowerL2Persistent() # type: ignore + + +def PersistThreadblock(): + """PersistThreadblock""" + return _ffi_api.PersistThreadblock() # type: ignore + + +def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16): + """AlignDynamicSharedMemoryAllocations + + Parameters + ---------- + align_bytes: int + The alignment bytes. + + Returns + ------- + """ + return _ffi_api.AlignDynamicSharedMemoryAllocations(align_bytes) # type: ignore + + +def LowerSharedBarrier(): + """LowerSharedBarrier""" + return _ffi_api.LowerSharedBarrier() # type: ignore + + +def PlanAndUpdateBufferAllocationLocation(): + """Plan and update buffer allocation locations within PrimFuncs. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore + + +def HoistNonRestrictParams(): + return _ffi_api.HoistNonRestrictParams() # type: ignore + + +def StorageRewrite(): + """StorageRewrite + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageRewrite() # type: ignore + + +def LowerOpaqueBlock(): + """LowerOpaqueBlock""" + return _ffi_api.LowerOpaqueBlock() # type: ignore + + +def LowerThreadAllreduce(): + """LowerThreadAllreduce""" + return _ffi_api.LowerThreadAllreduce() # type: ignore + + +def LowerIntrin(): + """LowerIntrin""" + return _ffi_api.LowerIntrin() # type: ignore + + +def LowerDeviceKernelLaunch(): + """ + Create and return a transform pass that lowers device kernel launch constructs to target-specific IR. + + This pass transforms high-level device kernel launch and related intrinsics into lower-level + IR suitable for backend code generation and device-side lowering. + + Returns: + tvm.transform.Pass: The transform pass that performs device kernel launch lowering. + """ + return _ffi_api.LowerDeviceKernelLaunch() # type: ignore + + +def LowerSharedTmem(): + """LowerSharedTmem""" + return _ffi_api.LowerSharedTmem() # type: ignore + + +def LayoutReducer(): + """ + Return a TVM transform pass that performs layout reduction/normalization. + + This wrapper delegates to the underlying FFI implementation and returns a pass object suitable for use in a PassContext or pass pipeline. The pass is intended to simplify or reduce tensor/layout-related representations during relay/tile transformations. + + Returns: + The transform pass object produced by the FFI backend. 
+ """ + return _ffi_api.LayoutReducer() # type: ignore diff --git a/tilelang/original/tilelang/transform/_ffi_api.py b/tilelang/original/tilelang/transform/_ffi_api.py new file mode 100644 index 0000000000000000000000000000000000000000..3692a32d64fe407e6a470278df375f280b75dcbf --- /dev/null +++ b/tilelang/original/tilelang/transform/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for tilelang""" + +import tvm_ffi + +# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); +tvm_ffi.init_ffi_api("tl.transform", __name__) diff --git a/tilelang/original/tilelang/transform/add_bufstore_wrapper.py b/tilelang/original/tilelang/transform/add_bufstore_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c1dd41e0dde17c7817416ea88fee9f63e17f1185 --- /dev/null +++ b/tilelang/original/tilelang/transform/add_bufstore_wrapper.py @@ -0,0 +1,154 @@ +from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm +from tvm.tir.stmt_functor import ir_transform, post_order_visit +from tvm.tir.transform import prim_func_pass + + +def AddWrapperForSingleBufStore(): + """ + Creates a TVM pass that wraps single buffer stores with parallel loops. + + This transformation adds T.Parallel wrappers around buffer stores that: + 1. Access fragment buffers with index 0 + 2. Are not inside existing tile operations or thread bindings + 3. Don't access fragment buffers with non-zero indices + + Returns: + A prim_func_pass that applies the transformation + """ + + def pass_fn(func: PrimFunc, mod, ctx): + # Counter for tracking nested tile operations + tile_operation_depth = 0 + # Set of variables bound to threads + thread_binding_vars = set() + + def get_used_variables(operation) -> set: + """ + Collects all variables used in the given operation. + + Args: + operation: The TIR operation to analyze + + Returns: + Set of variables used in the operation + """ + used_variables = set() + + def visit_variable(node): + if isinstance(node, Var): + used_variables.add(node) + + post_order_visit(operation, visit_variable) + return used_variables + + def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]: + """ + Categorizes buffers accessed in the statement by their scope. + + Args: + statement: The TIR statement to analyze + + Returns: + Tuple of (local_buffers, fragment_buffers) + """ + accessed_buffers = set() + + def visit_buffer_access(node): + if isinstance(node, (BufferLoad, BufferStore)): + accessed_buffers.add(node.buffer) + + post_order_visit(statement, visit_buffer_access) + + local_buffers = [] + fragment_buffers = [] + for buffer in accessed_buffers: + if buffer.scope() == "local.fragment": + fragment_buffers.append(buffer) + elif buffer.scope().startswith("local"): + local_buffers.append(buffer) + return local_buffers, fragment_buffers + + def collect_buffer_indices(statement) -> dict[Buffer, list[int]]: + """ + Maps each buffer to its access indices. + + Args: + statement: The TIR statement to analyze + + Returns: + Dictionary mapping buffers to their access indices + """ + buffer_to_indices = {} + + def visit_buffer_access(node): + if isinstance(node, (BufferLoad, BufferStore)): + buffer_to_indices[node.buffer] = node.indices + + post_order_visit(statement, visit_buffer_access) + return buffer_to_indices + + def is_tile_operation_loop(loop: For) -> bool: + """ + Determines if a For loop is a tile operation. 
+ + Args: + loop: The For loop to check + + Returns: + True if the loop is a tile operation (parallel or has num_stages annotation) + """ + return loop.kind == ForKind.PARALLEL or "num_stages" in loop.annotations + + def pre_visit(statement): + """ + Pre-order visitor that tracks thread bindings and tile operation depth. + """ + nonlocal tile_operation_depth + + if isinstance(statement, AttrStmt) and statement.attr_key == "thread_extent": + thread_binding_vars.add(statement.node.var) + elif isinstance(statement, For) and is_tile_operation_loop(statement): + tile_operation_depth += 1 + + def post_visit(statement): + """ + Post-order visitor that applies transformations and updates counters. + """ + nonlocal tile_operation_depth + + if isinstance(statement, For) and is_tile_operation_loop(statement): + tile_operation_depth -= 1 + + elif isinstance(statement, BufferStore): + used_variables = get_used_variables(statement) + thread_bound_variables = used_variables.intersection(thread_binding_vars) + + # Only transform if not inside tile operations and no thread bindings + if tile_operation_depth == 0 and len(thread_bound_variables) == 0: + # Skip if no fragment buffers are accessed + _, fragment_buffers = collect_buffer_accesses(statement) + if len(fragment_buffers) == 0: + return statement + + # Validate fragment buffer indices - only index 0 is supported + buffer_indices = collect_buffer_indices(statement) + for buffer, indices in buffer_indices.items(): + if buffer.scope() != "local.fragment": + continue + for index in indices: + if isinstance(index, IntImm) and index != 0: + raise ValueError( + f"Fragment buffer access with non-zero index [{index}] is not supported. " + "Only fragment[0] access is allowed." + ) + + # Wrap fragment[0] access with T.Parallel loop + return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement) + + return statement + + new_body = ir_transform(func.body, pre_visit, post_visit) + + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/original/tilelang/transform/pass_config.py b/tilelang/original/tilelang/transform/pass_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb82106d4e69b5eeea16400b51ff8904031aa6e --- /dev/null +++ b/tilelang/original/tilelang/transform/pass_config.py @@ -0,0 +1,162 @@ +# TODO: Add more documentation for each pass config + +from enum import Enum + + +class PassConfigKey(str, Enum): + """Pass configuration keys for TileLang compiler.""" + + # TileLang specific configs + TL_SIMPLIFY = "tl.Simplify" + """Enable/disable TileLang simplification passes. Default: True""" + + TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized" + """Disable warp specialization optimization. Default: False""" + + TL_DISABLE_FAST_MATH = "tl.disable_fast_math" + """Disable fast math optimization. Default: True + will be deprecated in the 0.1.7 release + """ + + TL_ENABLE_FAST_MATH = "tl.enable_fast_math" + """ + Enable fast math optimization. Default: False + if enabled, --use_fast_math will be passed to nvcc + """ + + TL_PTXAS_REGISTER_USAGE_LEVEL = "tl.ptxas_register_usage_level" + """The PTXAS register usage level in [0, 10], which controls the + aggressiveness of optimizations that affect register usage. Default: None""" + + TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output" + """Enable ptxas verbose output. Default: False""" + + TL_DEVICE_COMPILE_FLAGS = "tl.device_compile_flags" + """Additional device compiler flags passed to nvcc/NVRTC. 
+ + Accepts either a string (parsed with shell-like splitting) or a list of + strings. Typical usage is to provide extra include paths, defines or + ptxas options, e.g.: + + - "-I/opt/include -DMY_SWITCH=1 --ptxas-options=--verbose" + - ["-I/opt/include", "-DMY_SWITCH=1", "--ptxas-options=--verbose"] + + These flags are appended to the compiler options used in the tvm_ffi + CUDA compile callback. Default: None + """ + + TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth" + """Bitwidth for configuration indices. Default: 32""" + + TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower" + """Disable TMA (Tensor Memory Access) lowering. Default: False""" + + TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize" + """Disable safe memory access optimization. Default: False""" + + TL_DISABLE_VECTORIZE_256 = "tl.disable_vectorize_256" + """Disable usage of LDG/STG 256. Default: False""" + TL_DISABLE_WGMMA = "tl.disable_wgmma" + """Disable usage of Hopper WGMMA. Default: False""" + + TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations" + """Enable debug information for merge shared memory allocations. Default: False""" + + TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge" + """Enable aggressive merge of shared memory allocations. Default: False""" + + TL_DISABLE_SHUFFLE_ELECT = "tl.disable_shuffle_elect" + """Disable shuffle election optimization. Default: False""" + + TL_DISABLE_THREAD_STORAGE_SYNC = "tl.disable_thread_storage_sync" + """Disable thread storage synchronization pass. When enabled, disables the + automatic insertion of thread synchronization barriers (e.g., __syncthreads()) + for shared memory access coordination. This can be useful for performance + optimization in cases where manual synchronization is preferred or when + synchronization is not needed. Default: False""" + + TL_FORCE_LET_INLINE = "tl.force_let_inline" + """Force TileLang to inline let bindings during simplification. Default: False""" + + TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable" + """Enable layout inference visualization. Default: False""" + + TL_LAYOUT_VISUALIZATION_FORMATS = "tl.layout_visualization_formats" + """Layout visualization formats. + Acceptable values: "pdf", "png", "svg", "all" + + """ + + TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" + """Control StorageRewrite inplace detection. + + When False (default) StorageRewrite keeps distinct temporaries for patterns + such as `dst[i] = f(src[i])`, avoiding implicit aliasing: + + ``` + read = T.allocate([1], T.int32, "local.var") + write = T.allocate([1], T.int32, "local.var") + read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var") + write_buf = T.Buffer((1,), T.int32, data=write, scope="local.var") + write_buf[0] = read_buf[0] * 2 + f(write_buf[0]) + ``` + + Setting the flag to True allows StorageRewrite to reuse the `read` buffer + for the write when it can prove the update is safely inplace, producing IR + like: + + ``` + read = T.allocate([1], T.int32, "local.var") + read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var") + read_buf[0] = read_buf[0] * 2 + f(read_buf[0]) + ``` + + This reduces local memory usage but introduces aliasing between the buffers. 
+ + Usage: + + ```python + from tilelang.transform import PassContext, PassConfigKey + + with PassContext( + config={PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE.value: True} + ): + mod = tilelang.transform.StorageRewrite()(mod) + ``` + """ + + # TIR related configs + TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir" + """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True""" + + TIR_DISABLE_CSE = "tir.disable_cse_tir" + """Disable TIR Common Subexpression Elimination. Default: False""" + + TIR_SIMPLIFY = "tir.Simplify" + """Enable/disable TIR simplification passes. Default: True""" + + TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite" + """Disable storage rewrite optimization. Default: False""" + + TIR_DISABLE_VECTORIZE = "tir.disable_vectorize" + """Disable vectorization optimization. Default: False""" + + TIR_USE_ASYNC_COPY = "tir.use_async_copy" + """Enable asynchronous memory copy operations. Default: True""" + + TIR_ENABLE_DEBUG = "tir.enable_debug" + """Enable debug information in generated code. Default: False""" + + TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem" + """Merge static shared memory allocations. Default: True""" + + TIR_ADD_LOWER_PASS = "tir.add_lower_pass" + """Additional lowering passes to be applied. Default: None""" + + TIR_NOALIAS = "tir.noalias" + """Enable pointer non-aliasing assumptions. Default: True""" + + CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir" + """Output directory for generated CUDA kernels. Default: empty string""" diff --git a/tilelang/original/tilelang/transform/simplify.py b/tilelang/original/tilelang/transform/simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e577d036a3b4553d66c19d1b04d79108249cbf --- /dev/null +++ b/tilelang/original/tilelang/transform/simplify.py @@ -0,0 +1,63 @@ +from __future__ import annotations +from tilelang import tvm as tvm +from tvm import IRModule +from tvm.tir import PrimFunc +from typing import Callable +from . 
import _ffi_api + + +def LetInline(): + """LetInline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LetInline() # type: ignore + + +def Simplify(simplify_arguments: bool = False): + """Simplify + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.Simplify(simplify_arguments) # type: ignore + + +def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule: + if isinstance(stmt, PrimFunc): + if inline_let: + mod = LetInline()(IRModule.from_expr(stmt)) + mod = Simplify(simplify_arguments=True)(mod) + else: + mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt)) + assert len(mod.functions) == 1, "Simplify should return a single function" + return list(mod.functions.values()).pop() + elif isinstance(stmt, IRModule): + if inline_let: + mod = LetInline()(stmt) + mod = Simplify(simplify_arguments=True)(mod) + else: + mod = Simplify(simplify_arguments=True)(stmt) + assert len(mod.functions) == 1, "Simplify should return a single function" + return list(mod.functions.values()).pop() + else: + raise ValueError(f"Unsupported type: {type(stmt)}") + + +# Decorator to simplify the output of a function +def simplify_prim_func(func: Callable) -> Callable: + def wrapper(*args, **kwargs): + stmt: PrimFunc | IRModule = (func)(*args, **kwargs) + return _Simplify(stmt) + + return wrapper + + +def apply_simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule: + """Apply Simplify pass to a PrimFunc or IRModule.""" + return _Simplify(stmt, inline_let) diff --git a/tilelang/original/tilelang/utils/__init__.py b/tilelang/original/tilelang/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df0c71c2e3a5fdd8f80943b158ef75077a73c1d6 --- /dev/null +++ b/tilelang/original/tilelang/utils/__init__.py @@ -0,0 +1,21 @@ +"""The profiler and convert to torch utils""" + +from .target import determine_target # noqa: F401 +from .tensor import TensorSupplyType, torch_assert_close, map_torch_type # noqa: F401 +from .language import ( + is_global, # noqa: F401 + is_shared, # noqa: F401 + is_shared_dynamic, # noqa: F401 + is_tensor_memory, # noqa: F401 + is_fragment, # noqa: F401 + is_local, # noqa: F401 + array_reduce, # noqa: F401 + retrieve_stride, # noqa: F401 + retrieve_shape, # noqa: F401 + retrive_ptr_from_buffer_region, # noqa: F401 + is_full_region, # noqa: F401 + to_buffer_region, # noqa: F401 + get_buffer_region_from_load, # noqa: F401 + get_prim_func_name, # noqa: F401 +) +from .deprecated import deprecated # noqa: F401 diff --git a/tilelang/original/tilelang/utils/deprecated.py b/tilelang/original/tilelang/utils/deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..2944f292b308dd88ae2006433e96b09e0af5069e --- /dev/null +++ b/tilelang/original/tilelang/utils/deprecated.py @@ -0,0 +1,39 @@ +def deprecated_warning(method_name: str, new_method_name: str, phaseout_version: str = None): + """A function to indicate that a method is deprecated""" + import warnings # pylint: disable=import-outside-toplevel, import-error + + warnings.warn( + f"{method_name} is deprecated, use {new_method_name} instead" + + (f" and will be removed in {phaseout_version}" if phaseout_version else ""), + DeprecationWarning, + stacklevel=2, + ) + + +def deprecated( + method_name: str, + new_method_name: str, + phaseout_version: str = None, +): + """A decorator to indicate that a method is deprecated + + Parameters + ---------- + 
method_name : str + The name of the method to deprecate + new_method_name : str + The name of the new method to use instead + phaseout_version : str + The version to phase out the method + """ + import functools # pylint: disable=import-outside-toplevel + + def _deprecate(func): + @functools.wraps(func) + def _wrapper(*args, **kwargs): + deprecated_warning(method_name, new_method_name, phaseout_version) + return func(*args, **kwargs) + + return _wrapper + + return _deprecate diff --git a/tilelang/original/tilelang/utils/device.py b/tilelang/original/tilelang/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2b6ca5aaf9fd4fe034914c6384f125c05340ea --- /dev/null +++ b/tilelang/original/tilelang/utils/device.py @@ -0,0 +1,21 @@ +import torch + +IS_CUDA = torch.cuda.is_available() + +IS_MPS = False +try: + IS_MPS = torch.backends.mps.is_available() +except AttributeError: + print("MPS backend is not available in this PyTorch build.") +except Exception as e: + print(f"An unexpected error occurred while checking MPS availability: {e}") + + +def get_current_device(): + device = None + if IS_CUDA: + device = torch.cuda.current_device() + elif IS_MPS: + device = "mps:0" + + return device diff --git a/tilelang/original/tilelang/utils/language.py b/tilelang/original/tilelang/utils/language.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8e588043a49df5e2006920ca89233eb82b5eda --- /dev/null +++ b/tilelang/original/tilelang/utils/language.py @@ -0,0 +1,504 @@ +from __future__ import annotations +from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr +from tilelang.language.utils import region as _make_region_call +from functools import reduce +from tvm import IRModule, DataType +from tvm.tir import PrimFunc +from tvm import ir, tir +# Scope Checkers for TVM Buffers +# These utility functions check the memory scope of a given TVM buffer. + + +def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> Buffer: + """ + Extract Buffer from Buffer, BufferLoad, or BufferRegion. + + Args: + buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion + + Returns: + Buffer: The underlying buffer object + """ + if isinstance(buffer_or_load_or_region, Buffer): + return buffer_or_load_or_region + elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)): + return buffer_or_load_or_region.buffer + else: + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + + +def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool: + """ + Check if the buffer is in the global memory scope. + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in global memory, False otherwise. + """ + buffer = _get_buffer(buffer) + return buffer.scope() == "global" + + +def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = True) -> bool: + """ + Check if the buffer is in the shared memory scope. + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in shared memory, False otherwise. 
+ """ + buffer = _get_buffer(buffer) + conditions = [False] + conditions.append(buffer.scope() == "shared") + if allow_dynamic: + conditions.append(is_shared_dynamic(buffer)) + return any(conditions) + + +def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool: + """ + Check if the buffer is in the dynamic shared memory scope. + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in dynamic shared memory, False otherwise. + """ + buffer = _get_buffer(buffer) + return buffer.scope() == "shared.dyn" + + +def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool: + """ + Check if the buffer is in tensor memory scope (e.g., shared.tmem). + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in tensor memory, False otherwise. + """ + buffer = _get_buffer(buffer) + return buffer.scope().startswith("shared.tmem") + + +def is_local(buffer: Buffer | BufferLoad | BufferRegion) -> bool: + """ + Check if the buffer is in the local memory scope. + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in local memory, False otherwise. + """ + buffer = _get_buffer(buffer) + return buffer.scope() == "local" + + +def is_fragment(buffer: Buffer | BufferLoad | BufferRegion) -> bool: + """ + Check if the buffer is a fragment (e.g., for matrix multiplication operations). + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is a fragment, False otherwise. + """ + buffer = _get_buffer(buffer) + return buffer.scope().startswith("local.fragment") + + +def get_buffer_elems(buffer: Buffer) -> int: + """ + Get the number of elements in the buffer. + """ + return reduce(lambda x, y: x * y, buffer.shape) + + +def array_reduce(array: list[int]) -> int: + """ + Reduce an array of integers to a single integer. + + Args: + array (List[int]): The array of integers to reduce. + + Returns: + int: The reduced integer. + """ + return reduce(lambda x, y: x * y, array) + + +def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: + """ + Retrieve the single PrimFunc from an IRModule. + + Args: + ir_module (IRModule): The TVM IRModule to extract the function from. + The module should contain exactly one global function. + + Returns: + PrimFunc: The single function contained in the module. + + Raises: + ValueError: If ir_module is not an IRModule. + AssertionError: If the module contains more than one global function. + """ + if not isinstance(ir_module, IRModule): + raise ValueError("Not supported type: ", type(ir_module)) + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." + func = list(ir_module.functions.values())[0] + return func + + +def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: + """ + Get the buffer region from a buffer load. 
+ + May encounter buffer load like C[0:128, 0:32], ref to pull request + for buffer wise op: https://github.com/apache/tvm/pull/14693 + convert load to region + """ + buffer, indices = buffer_load.buffer, buffer_load.indices + regions = [] + found_ramp: bool = False + + if extents is not None: + assert len(extents) == len(indices), "extents should have the same length as indices" + for i, indice in enumerate(indices): + if isinstance(indice, tir.Ramp): + assert extents is None, "extents should be provided for BufferLoad with Ramp indices" + regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) + found_ramp = True + elif isinstance(indice, tir.PrimExpr): + if extents is not None: + regions.append(ir.Range.from_min_extent(indice, extents[i])) + found_ramp = True + else: + regions.append(ir.Range.from_min_extent(indice, 1)) + else: + raise ValueError(f"Unsupported type: {type(indice)} for index {i}") + if found_ramp: + return tir.BufferRegion(buffer, regions) + else: + return None + + +def to_buffer_region( + obj: Buffer | BufferLoad | BufferRegion | tir.Var, access_type: str = "rw", extents: list[PrimExpr] | None = None +) -> PrimExpr | BufferRegion: + """ + Convert to/from the tl.region representation. + + - Buffer/BufferLoad/BufferRegion -> returns a tl.region call (PrimExpr) + - tl.region Call -> returns the decoded BufferRegion for analysis + """ + from tilelang.language.frame import has_let_value, get_let_value + + if isinstance(obj, tir.Var) and has_let_value(obj): + obj = get_let_value(obj) + # Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis + if isinstance(obj, tir.BufferRegion): + if extents is None: + return obj + mins = [r.min for r in obj.region] + exts = [r.extent for r in obj.region] + assert len(extents) == len(exts) + exts = [tir.min(exts[i], extents[i]) for i in range(len(exts))] + return _make_region_call(tir.BufferLoad(obj.buffer, mins), access_type, *exts) + if isinstance(obj, tir.Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + if extents is None: + ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return tir.BufferRegion(obj, ranges) + exts = list(extents) + return _make_region_call(tir.BufferLoad(obj, mins), access_type, *exts) + if isinstance(obj, tir.BufferLoad): + if extents is None: + region = get_buffer_region_from_load(obj) + if region is not None: + return region + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return tir.BufferRegion(obj.buffer, ranges) + exts = list(extents) + if len(obj.indices) > len(exts): + exts = [tir.IntImm("int32", 1) for _ in range(len(obj.indices) - len(exts))] + exts + assert len(obj.indices) == len(exts) + return _make_region_call(obj, access_type, *exts) + raise ValueError(f"Unsupported argument type for to_buffer_region: {type(obj)}") + + +def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list: + """ + Retrieve shape-like extents for a buffer-like object. 
+ + - Buffer -> its `shape` + - BufferRegion -> list of each range's `extent` + - BufferLoad -> extents from `get_buffer_region_from_load(obj)` + """ + if isinstance(obj, tir.Buffer): + return obj.shape + if isinstance(obj, tir.BufferRegion): + return [r.extent for r in obj.region] + if isinstance(obj, tir.BufferLoad): + region = get_buffer_region_from_load(obj) + if region is None: + raise ValueError("Cannot retrieve shape from scalar BufferLoad without region") + return [r.extent for r in region.region] + raise ValueError(f"Unsupported retrieve_shape argument type: {type(obj)} for object {obj}") + + +def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list: + """ + Retrieve row-major strides for a buffer-like object based on its buffer.shape. + + For BufferRegion and BufferLoad, uses the underlying buffer's `shape`. + """ + if isinstance(obj, tir.Buffer): + shape = obj.shape + elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)): + shape = obj.buffer.shape + else: + raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}") + + strides = [] + stride = 1 + for s in reversed(shape): + strides.insert(0, stride) + stride *= s + return strides + + +def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, access_type: str = "r") -> PrimExpr: + if isinstance(buffer_or_load_or_region, Buffer): + return buffer_or_load_or_region.access_ptr(access_type) + elif isinstance(buffer_or_load_or_region, BufferLoad): + buffer_load = buffer_or_load_or_region + offset, stride = 0, 1 + buffer = buffer_load.buffer + for i, shape in enumerate(reversed(buffer.shape)): + indice = buffer_load.indices[len(buffer_load.indices) - i - 1] + if isinstance(indice, (tir.IntImm, tir.PrimExpr)): + offset += indice * stride + elif isinstance(indice, tir.Ramp): + offset += indice.base * stride + else: + raise ValueError(f"Unsupported index type: {type(indice)}") + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + elif isinstance(buffer_or_load_or_region, BufferRegion): + buffer_region = buffer_or_load_or_region + buffer = buffer_region.buffer + offset, stride = 0, 1 + for i, shape in enumerate(reversed(buffer.shape)): + offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + else: + raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}") + + +def retrieve_ptr( + obj: Buffer | BufferRegion | BufferLoad, + access_type: str = "r", + ignore_last_ndim: int = 0, +) -> PrimExpr: + """ + Retrieve a pointer to the start of a (possibly sliced) buffer region. + + - Buffer -> base pointer + - BufferRegion -> pointer with byte offset computed from region minima + - BufferLoad -> pointer offset computed from indices or derived region + + Args: + obj: Buffer-like object + access_type: TVM Buffer access mask, e.g. 
"r", "w", "rw" + ignore_last_ndim: do not offset the last N dimensions + """ + if isinstance(obj, tir.Buffer): + return obj.access_ptr(access_type) + + if isinstance(obj, tir.BufferRegion): + buffer, region = obj.buffer, obj.region + strides = retrieve_stride(obj) + # offset only over the leading dims, optionally ignoring the tail dims + upto = max(0, len(region) - int(ignore_last_ndim)) + offset = 0 + for i in range(upto): + offset += region[i].min * strides[i] + return buffer.access_ptr(access_type, offset=offset) + + if isinstance(obj, tir.BufferLoad): + buffer = obj.buffer + region = get_buffer_region_from_load(obj) + if region is not None: + mins = [r.min for r in region.region] + else: + mins = list(obj.indices) + strides = retrieve_stride(obj) + upto = max(0, len(mins) - int(ignore_last_ndim)) + offset = 0 + for i in range(upto): + offset += mins[i] * strides[i] + return buffer.access_ptr(access_type, offset=offset) + + raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}") + + +def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list: + """ + Retrieve per-dimension minima offsets. + + - Buffer -> [0, 0, ...] + - BufferRegion -> [r.min for r in region] + - BufferLoad -> indices (or derived region minima) + """ + if isinstance(obj, tir.Buffer): + return [0] * len(obj.shape) + if isinstance(obj, tir.BufferRegion): + return [r.min for r in obj.region] + if isinstance(obj, tir.BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return [r.min for r in region.region] + return list(obj.indices) + raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}") + + +def bits_product(shape: list[PrimExpr], dtype: str) -> PrimExpr: + """ + Compute the number of bits in a Buffer (shape with dtype).""" + if len(shape) == 0: + return tir.IntImm("int32", 1) + result = shape[0] + for i in range(1, len(shape)): + result = result * shape[i] + return result * DataType(dtype).bits + + +def prim_expr_equal(lhs, rhs) -> bool: + """ + Robust equality for PrimExpr shapes/extents. + + Tries structural_equal first, then falls back to expr_deep_equal. + Python ints are converted to IntImm for comparison. + """ + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs == rhs + if isinstance(lhs, int): + lhs = tir.IntImm("int32", lhs) + if isinstance(rhs, int): + rhs = tir.IntImm("int32", rhs) + if ir.structural_equal(lhs, rhs): + return True + return tir.analysis.expr_deep_equal(lhs, rhs) + + +def legalize_pairwise_extents(src_extents: list, dst_extents: list) -> tuple[list, list]: + """ + Right-align and broadcast two extent lists to be mutually compatible. + + Early-exit rule: + - If the number of non-1 dimensions in `src_extents` equals that in `dst_extents`, + no adjustment is made; the original extents are returned unchanged. This + preserves the per-dimension iteration mapping (one loop var per non-1 dim) + and avoids creating extra varying axes on either side. + + Otherwise, for each pair of tail-aligned dimensions (x, y): + - if x == y: keep both + - elif x == 1: set x = y + - elif y == 1: set y = x + - else: promote both to tir.max(x, y) to handle dynamic-vs-static safely + + Leading unmatched dimensions are kept as-is. + + Returns a tuple of new lists (src_new, dst_new). + """ + a = list(src_extents) + b = list(dst_extents) + + # If both sides have the same number of non-1 extents, don't re-broadcast. 
+ def _num_non_one(exts: list) -> int: + return sum(0 if prim_expr_equal(x, 1) else 1 for x in exts) + + if _num_non_one(a) == _num_non_one(b): + return a, b + k = min(len(a), len(b)) + for i in range(1, k + 1): + x, y = a[-i], b[-i] + if prim_expr_equal(x, y): + continue + elif prim_expr_equal(x, 1): + a[-i] = y + elif prim_expr_equal(y, 1): + b[-i] = x + else: + # Dynamic mismatch: promote to max so downstream clamping/predicates remain safe + m = tir.max(x, y) + a[-i] = m + b[-i] = m + return a, b + + +def is_full_region(buffer_region: BufferRegion) -> bool: + """ + Check whether a BufferRegion covers the full buffer region. + + A full region means each dimension has start 0 and extent equal to + the corresponding dimension in the buffer's shape. + + Args: + buffer_region: The TVM BufferRegion to check. + + Returns: + bool: True if the region is full; otherwise False. + """ + if not isinstance(buffer_region, tir.BufferRegion): + raise TypeError(f"Expected BufferRegion, got {type(buffer_region)}") + + buf = buffer_region.buffer + ranges = buffer_region.region + + if len(buf.shape) != len(ranges): + return False + + expr_equal = tir.analysis.expr_deep_equal + for dim, r in zip(buf.shape, ranges): + # start == 0 and extent == shape + if not expr_equal(r.min, 0): + return False + if not expr_equal(r.extent, dim): + return False + return True + + +def get_prim_func_name(func: PrimFunc | None, default: str | None = None) -> str | None: + """ + Extract a human‑readable function name from a TVM PrimFunc. + + Prefer the `global_symbol` attribute set on the PrimFunc. If it is missing + (e.g., private PrimFunc without a global symbol), return the provided + `default` value. + + Args: + func: TVM PrimFunc instance or None. + default: Fallback name to return when no name can be determined. + + Returns: + The function name as a string, or `default` when unavailable. 
+ """ + if func is None: + return default + try: + name = func.attrs["global_symbol"] + return str(name) if name is not None else default + except Exception: + return default diff --git a/tilelang/original/tilelang/utils/sparse.py b/tilelang/original/tilelang/utils/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..fa227b07dd503bceccbf0f142c86a9383b68fbf6 --- /dev/null +++ b/tilelang/original/tilelang/utils/sparse.py @@ -0,0 +1,170 @@ +from __future__ import annotations +import os +import torch +import warnings +from tilelang.contrib import nvcc +from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_ +from torch.utils.cpp_extension import load, _import_module_from_library +from tilelang import env + +# Include version information to ensure different versions use separate caches +from tilelang import __version__ + +# Define paths +compress_util = os.path.join(env.TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu") +# Cache directory for compiled extensions +_CACHE_DIR = os.path.join(env.TILELANG_CACHE_DIR, "sparse_compressor", __version__) +os.makedirs(_CACHE_DIR, exist_ok=True) + + +def _get_cached_lib(): + name = "compress_lib" + + if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")): + try: + return _import_module_from_library(name, _CACHE_DIR, is_python_module=True) + except Exception: + pass + + # Set TORCH_CUDA_ARCH_LIST + env._initialize_torch_cuda_arch_flags() + + # Compile if not cached or loading failed + return load( + name=name, + sources=[compress_util], + extra_cuda_cflags=[ + "-O2", + "-std=c++17", + "-lineinfo", + f"-I{env.CUTLASS_INCLUDE_DIR}", + f"-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include", + "-arch=sm_90", + ], + build_directory=_CACHE_DIR, + ) + + +def compress_sm90(A: torch.Tensor, block_k: int, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: + if block_k > 128: + block_k = 128 + # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 + warnings.warn(f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2) + # Load the library (will use cache if available) + compress_lib = _get_cached_lib() + + return compress_lib.compress_sm90(A, block_k, transposed) + + +def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: + try: + from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor + except ImportError as err: + raise ImportError( + "SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version." + ) from err + orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS + try: + SparseSemiStructuredTensor._FORCE_CUTLASS = True + if transposed is not False: + raise NotImplementedError("transposed flag is deprecated by pytorch") + compressed = to_sparse_semi_structured(A) + return compressed.packed, compressed.meta + finally: + SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val + + +def compress(A: torch.Tensor, transposed: bool, arch: str | None = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compress a tensor using the appropriate method based on the CUDA architecture. 
+ """ + if arch is None: + arch = nvcc.get_target_compute_version() + + compute_version = nvcc.parse_compute_version(arch) + + if compute_version >= (9, 0): + return compress_sm90(A, transposed=transposed, **kwargs) + elif compute_version >= (8, 0): + if transposed: + A = A.t().contiguous() + origin_dtype = A.dtype + if is_float8_dtype(origin_dtype): + fp8_remove_negative_zeros_(A) + A = A.view(torch.int8) + A_sp, E = compress_sm80(A, transposed=False) + if is_float8_dtype(origin_dtype): + A_sp = A_sp.view(origin_dtype) + if transposed: + A_sp = A_sp.t().contiguous() + return A_sp, E + else: + raise ValueError(f"Unsupported CUDA compute version: {compute_version}. Supported versions are sm_80 and sm_90.") + + +def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): + """ + Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension. + Args: + M (int): Number of rows + K (int): Number of columns + dtype: Data type of the tensor + device: Device to create the tensor on + transposed (bool): If True, returns a transposed tensor of shape (K, M) + """ + elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 + tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group) + indice = tensor.topk(elem, dim=-1).indices + tensor.scatter_(-1, indice, 0) + tensor = tensor.view(M, K) + if transposed: + tensor = tensor.t().contiguous() + return tensor.to(dtype) # dtype like float8 might not have randn kernel + + +def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device="cuda", transposed: bool = False): + """ + Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension. + Args: + M (int): Number of rows + K (int): Number of columns + low (int): Lower bound of the random integers + high (int): Upper bound of the random integers + dtype: Data type of the tensor + device: Device to create the tensor on + transposed (bool): If True, returns a transposed tensor of shape (K, M) + """ + elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 + tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group) + indice = tensor.topk(elem, dim=-1).indices + tensor.scatter_(-1, indice, 0) + tensor = tensor.view(M, K) + if transposed: + tensor = tensor.t().contiguous() + return tensor + + +def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): + """ + Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension. 
+ Args: + M (int): Number of rows + K (int): Number of columns + dtype: Data type of the tensor + device: Device to create the tensor on + transposed (bool): If True, returns a transposed tensor of shape (K, M) + """ + elem, group = 2, 4 + if dtype == torch.float32: + elem, group = 1, 2 + tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group) + indice = tensor.topk(elem, dim=-1).indices + tensor.scatter_(-1, indice, 0) + tensor = tensor.view(M, K) + if transposed: + tensor = tensor.t().contiguous() + return tensor diff --git a/tilelang/original/tilelang/utils/target.py b/tilelang/original/tilelang/utils/target.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b88f5e8c7fbe9fb1f10642924fb05206621971 --- /dev/null +++ b/tilelang/original/tilelang/utils/target.py @@ -0,0 +1,185 @@ +from __future__ import annotations +from platform import mac_ver +from typing import Literal +from tilelang import tvm as tvm +from tilelang import _ffi_api +from tvm.target import Target +from tvm.contrib import rocm +from tilelang.contrib import nvcc + +SUPPORTED_TARGETS: dict[str, str] = { + "auto": "Auto-detect CUDA/HIP/Metal based on availability.", + "cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).", + "hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).", + "metal": "Apple Metal target for arm64 Macs.", + "llvm": "LLVM CPU target (accepts standard TVM LLVM options).", + "webgpu": "WebGPU target for browser/WebGPU runtimes.", + "c": "C source backend.", + "cutedsl": "CuTe DSL GPU target.", +} + + +def describe_supported_targets() -> dict[str, str]: + """ + Return a mapping of supported target names to usage descriptions. + """ + return dict(SUPPORTED_TARGETS) + + +def check_cuda_availability() -> bool: + """ + Check if CUDA is available on the system by locating the CUDA path. + Returns: + bool: True if CUDA is available, False otherwise. + """ + try: + nvcc.find_cuda_path() + return True + except Exception: + return False + + +def check_hip_availability() -> bool: + """ + Check if HIP (ROCm) is available on the system by locating the ROCm path. + Returns: + bool: True if HIP is available, False otherwise. + """ + try: + rocm.find_rocm_path() + return True + except Exception: + return False + + +def check_metal_availability() -> bool: + mac_release, _, arch = mac_ver() + if not mac_release: + return False + # todo: check torch version? + return arch == "arm64" + + +def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target: + """ + Determine the appropriate target for compilation (CUDA, HIP, or manual selection). + + Args: + target (Union[str, Target, Literal["auto"]]): User-specified target. + - If "auto", the system will automatically detect whether CUDA or HIP is available. + - If a string or Target, it is directly validated. + + Returns: + Union[str, Target]: The selected target ("cuda", "hip", or a valid Target object). + + Raises: + ValueError: If no CUDA or HIP is available and the target is "auto". + AssertionError: If the target is invalid. 
+ """ + + return_var: str | Target = target + + if target == "auto": + target = tvm.target.Target.current(allow_none=True) + if target is not None: + return target + # Check for CUDA and HIP availability + is_cuda_available = check_cuda_availability() + is_hip_available = check_hip_availability() + + # Determine the target based on availability + if is_cuda_available: + return_var = "cuda" + elif is_hip_available: + return_var = "hip" + elif check_metal_availability(): + return_var = "metal" + else: + raise ValueError("No CUDA or HIP or MPS available on this system.") + elif isinstance(target, str) and target.startswith("cutedsl"): + cuda_target_str = target.replace("cutedsl", "cuda", 1) + temp_target = Target(cuda_target_str) + + target_dict = dict(temp_target.export()) + target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"] + + return_var = Target(target_dict) + else: + # Validate the target if it's not "auto" + if isinstance(target, Target): + return_var = target + elif isinstance(target, str): + normalized_target = target.strip() + if not normalized_target: + raise AssertionError(f"Target {target} is not supported") + try: + Target(normalized_target) + except Exception as err: + examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS) + raise AssertionError( + f"Target {target} is not supported. Supported targets include: {examples}. " + "Pass additional options after the base name, e.g. `cuda -arch=sm_80`." + ) from err + return_var = normalized_target + else: + raise AssertionError(f"Target {target} is not supported") + + if isinstance(return_var, Target): + return return_var + if return_object: + if isinstance(return_var, Target): + return return_var + return Target(return_var) + return return_var + + +def target_is_cuda(target: Target) -> bool: + return _ffi_api.TargetIsCuda(target) + + +def target_is_hip(target: Target) -> bool: + return _ffi_api.TargetIsRocm(target) + + +def target_is_volta(target: Target) -> bool: + return _ffi_api.TargetIsVolta(target) + + +def target_is_turing(target: Target) -> bool: + return _ffi_api.TargetIsTuring(target) + + +def target_is_ampere(target: Target) -> bool: + return _ffi_api.TargetIsAmpere(target) + + +def target_is_hopper(target: Target) -> bool: + return _ffi_api.TargetIsHopper(target) + + +def target_is_sm120(target: Target) -> bool: + return _ffi_api.TargetIsSM120(target) + + +def target_is_cdna(target: Target) -> bool: + return _ffi_api.TargetIsCDNA(target) + + +def target_has_async_copy(target: Target) -> bool: + return _ffi_api.TargetHasAsyncCopy(target) + + +def target_has_ldmatrix(target: Target) -> bool: + return _ffi_api.TargetHasLdmatrix(target) + + +def target_has_stmatrix(target: Target) -> bool: + return _ffi_api.TargetHasStmatrix(target) + + +def target_has_bulk_copy(target: Target) -> bool: + return _ffi_api.TargetHasBulkCopy(target) + + +def target_get_warp_size(target: Target) -> int: + return _ffi_api.TargetGetWarpSize(target) diff --git a/tilelang/original/tilelang/utils/tensor.py b/tilelang/original/tilelang/utils/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..13ce194fb296e34f26a3c37de803e9ae93aed5a0 --- /dev/null +++ b/tilelang/original/tilelang/utils/tensor.py @@ -0,0 +1,319 @@ +"""The profiler and convert to torch utils""" + +from enum import Enum +import torch +from tvm import tir +import numpy as np + + +def is_float8_dtype(dtype: torch.dtype) -> bool: + return dtype in { + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + } + 
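(Editorial aside, not part of the patch: a minimal usage sketch of how `is_float8_dtype` pairs with the `fp8_remove_negative_zeros_` helper defined just below. It assumes a CUDA-capable PyTorch build with float8 support (torch >= 2.1); `compress()` in `tilelang/utils/sparse.py` applies the same fixup before viewing float8 storage as int8.)

```python
import torch

from tilelang.utils.tensor import fp8_remove_negative_zeros_, is_float8_dtype

# Hypothetical input: float8 activations that will later be reinterpreted as int8.
x = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)

if is_float8_dtype(x.dtype):
    # Clear -0.0 bit patterns in place so the raw-byte view below is canonical.
    fp8_remove_negative_zeros_(x)

x_bits = x.view(torch.int8)  # safe to compare or compress at the byte level now
```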
+ 
+ 
+def fp8_remove_negative_zeros_(tensor: torch.Tensor): 
+    assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype" 
+    bits = tensor.view(torch.uint8) 
+    zeros_mask = tensor == 0 
+    bits[zeros_mask] = 0x00 
+ 
+ 
+class TensorSupplyType(Enum): 
+    Integer = 1 
+    Uniform = 2 
+    Normal = 3 
+    Randn = 4 
+    Zero = 5 
+    One = 6 
+    Auto = 7 
+ 
+ 
+def map_torch_type(intype) -> torch.dtype: 
+    # Convert to string if needed 
+    if not isinstance(intype, str): 
+        intype = str(intype) 
+ 
+    if intype == "float8_e4m3": 
+        assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torch. Please upgrade torch >= 2.1.0" 
+        return torch.float8_e4m3fn 
+    elif intype == "float8_e5m2": 
+        assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0" 
+        return torch.float8_e5m2 
+    elif intype == "e4m3fnuz_float8": 
+        assert hasattr(torch, "float8_e4m3fnuz"), ( 
+            "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" 
+        ) 
+        return torch.float8_e4m3fnuz 
+    elif intype == "float8_e8m0fnu": 
+        assert hasattr(torch, "float8_e8m0fnu"), ( 
+            "torch.float8_e8m0fnu is not supported in this version of torch. Please upgrade torch >= 2.8.0" 
+        ) 
+        return torch.float8_e8m0fnu 
+    elif intype == "float4_e2m1fnx2": 
+        assert hasattr(torch, "float4_e2m1fnx2"), ( 
+            "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0" 
+        ) 
+        return torch.float4_e2m1fnx2 
+    elif "float4" in intype: 
+        # PyTorch doesn't support float4, use int8 as storage type 
+        return torch.int8 
+    else: 
+        return getattr(torch, intype) 
+ 
+ 
+def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): 
+    from tilelang.engine.param import KernelParam 
+    from .device import get_current_device 
+ 
+    def get_tensor(param: KernelParam) -> torch.Tensor: 
+        # Convert tvm.DataType to torch.dtype for tensor creation 
+        dtype: torch.dtype = param.torch_dtype() 
+        device = get_current_device() 
+ 
+        if hasattr(param, "shape") and not param.shape: 
+            raise ValueError( 
+                f"TensorType must have a shape, but got {type(param)}, " 
+                "likely you are trying to generate a random tensor with a dynamic symbolic shape." 
+            ) 
+ 
+        # Check if with dynamic symbolic shape 
+        for shape in param.shape: 
+            if isinstance(shape, tir.Var): 
+                raise ValueError( 
+                    f"TensorType must have a static shape, but got {shape}, " 
+                    "likely you are trying to generate a random tensor with a dynamic symbolic shape." 
+ ) + + shape = list(map(int, param.shape)) + if supply_type == TensorSupplyType.Auto: + is_unsigned = param.is_unsigned() + is_float8 = param.is_float8() + is_float4 = param.is_float4() + is_boolean = param.is_boolean() + if is_unsigned: + return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) + elif is_float8: + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + elif is_float4: + return torch.randint(low=0, high=16, size=shape, device=device, dtype=dtype) + elif is_boolean: + return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) + elif dtype in {torch.float16, torch.float32, torch.bfloat16}: + return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0) + else: + return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) + + if dtype == torch.int8 and supply_type in [ + TensorSupplyType.Uniform, + TensorSupplyType.Normal, + ]: + return torch.ones(*shape, device=device, dtype=dtype) + + if supply_type == TensorSupplyType.Integer: + is_unsigned = param.is_unsigned() + is_float8 = param.is_float8() + is_float4 = param.is_float4() + is_boolean = param.is_boolean() + if is_unsigned: + return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) + elif is_float8: + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + elif is_float4: + return torch.randint(low=0, high=16, size=shape, device=device, dtype=dtype) + elif is_boolean: + return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) + else: + return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.Uniform: + return torch.empty(*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) + elif supply_type == TensorSupplyType.Normal: + return torch.empty(*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) + elif supply_type == TensorSupplyType.Randn: + return torch.randn(*shape, device=device).to(dtype) + elif supply_type == TensorSupplyType.Zero: + return torch.zeros(*shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.One: + return torch.ones(*shape, device=device, dtype=dtype) + else: + raise NotImplementedError(supply_type) + + return get_tensor + + +# Adapted from https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py +def _compare_attributes( + actual: torch.Tensor, + expected: torch.Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +) -> None: + """Checks if the attributes of two tensors match. + Always checks + - the :attr:`~torch.Tensor.shape`, + - whether both inputs are quantized or not, + - and if they use the same quantization scheme. + Checks for + - :attr:`~torch.Tensor.layout`, + - :meth:`~torch.Tensor.stride`, + - :attr:`~torch.Tensor.device`, and + - :attr:`~torch.Tensor.dtype` + are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair. 
+ """ + + def raise_mismatch_error(attribute_name: str, actual_value, expected_value): + raise AssertionError(f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.") + + if actual.shape != expected.shape: + raise_mismatch_error("shape", actual.shape, expected.shape) + if actual.is_quantized != expected.is_quantized: + raise_mismatch_error("is_quantized", actual.is_quantized, expected.is_quantized) + elif actual.is_quantized and actual.qscheme() != expected.qscheme(): + raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme()) + if actual.layout != expected.layout: + if check_layout: + raise_mismatch_error("layout", actual.layout, expected.layout) + elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride(): + raise_mismatch_error("stride()", actual.stride(), expected.stride()) + if check_device and actual.device != expected.device: + raise_mismatch_error("device", actual.device, expected.device) + if check_dtype and actual.dtype != expected.dtype: + raise_mismatch_error("dtype", actual.dtype, expected.dtype) + + +def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Equalizes some attributes of two tensors for value comparison. + If ``actual`` and ``expected`` are ... + - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. + - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`). + - ... not of the same ``layout``, they are converted to strided tensors. + Args: + actual (Tensor): Actual tensor. + expected (Tensor): Expected tensor. + Returns: + (Tuple[Tensor, Tensor]): Equalized tensors. + """ + # The comparison logic uses operators currently not supported by the MPS backends. + # See https://github.com/pytorch/pytorch/issues/77144 for details. 
+ # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend + if actual.is_mps or expected.is_mps: # type: ignore[attr-defined] + actual = actual.cpu() + expected = expected.cpu() + if actual.device != expected.device: + actual = actual.cpu() + expected = expected.cpu() + if actual.dtype != expected.dtype: + actual_dtype = actual.dtype + expected_dtype = expected.dtype + # For uint64, this is not sound in general, which is why promote_types doesn't + # allow it, but for easy testing, we're unlikely to get confused + # by large uint64 overflowing into negative int64 + if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]: + actual_dtype = torch.int64 + if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]: + expected_dtype = torch.int64 + dtype = torch.promote_types(actual_dtype, expected_dtype) + actual = actual.to(dtype) + expected = expected.to(dtype) + if actual.layout != expected.layout: + # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided + actual = actual.to_dense() if actual.layout != torch.strided else actual + expected = expected.to_dense() if expected.layout != torch.strided else expected + return actual, expected + + +def torch_assert_close( + tensor_a, + tensor_b, + rtol=1e-2, + atol=1e-3, + max_mismatched_ratio=0.001, + verbose: bool = False, + equal_nan: bool = True, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + base_name: str = "LHS", + ref_name: str = "RHS", +): + """ + Custom function to assert that two tensors are "close enough," allowing a specified + percentage of mismatched elements. + + Parameters: + ---------- + tensor_a : torch.Tensor + The first tensor to compare. + tensor_b : torch.Tensor + The second tensor to compare. + rtol : float, optional + Relative tolerance for comparison. Default is 1e-2. + atol : float, optional + Absolute tolerance for comparison. Default is 1e-3. + max_mismatched_ratio : float, optional + Maximum ratio of mismatched elements allowed (relative to the total number of elements). + Default is 0.001 (0.1% of total elements). + + Raises: + ------- + AssertionError: + If the ratio of mismatched elements exceeds `max_mismatched_ratio`. 
+ """ + + _compare_attributes( + tensor_a, tensor_b, check_device=check_device, check_dtype=check_dtype, check_layout=check_layout, check_stride=check_stride + ) + tensor_a, tensor_b = _equalize_attributes(tensor_a, tensor_b) + + mismatched = ~torch.isclose(tensor_a, tensor_b, rtol=rtol, atol=atol, equal_nan=equal_nan) + # Compute the absolute difference between the two tensors + diff = torch.abs(tensor_a - tensor_b) + # Count the number of mismatched elements + num_mismatched = mismatched.sum().item() + + # Calculate the total number of elements in the tensor + total_elements = tensor_a.numel() + + # Compute the allowed mismatched elements based on the ratio + max_allowed_mismatched = int(total_elements * max_mismatched_ratio) + + # Print debug information about the mismatch + if verbose: + print(f"Number of mismatched elements: {num_mismatched} / {total_elements} (allowed: {max_allowed_mismatched})") + + # If there are mismatched elements, print the first mismatch + if num_mismatched > 0: + # Find the first mismatch index + flat_idx = torch.argmax(mismatched.view(-1).int()).item() + idx = np.unravel_index(flat_idx, tensor_a.shape) + idx = [int(i) for i in idx] + a_val = tensor_a.reshape(-1)[flat_idx].item() + b_val = tensor_b.reshape(-1)[flat_idx].item() + abs_diff = abs(a_val - b_val) + rel_diff = abs_diff / (abs(b_val) + 1e-12) + mismatch_info = ( + f"\nFirst mismatch at index {idx}: lhs={a_val:.6f}, rhs={b_val:.6f}, abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}" + ) + else: + mismatch_info = "" + + # Modify the exception information + if num_mismatched > max_allowed_mismatched: + raise AssertionError( + f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} " + f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)." + f"{mismatch_info}" + f"\nGreatest absolute difference: {diff.max().item()}, " + f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}" + f"\n{base_name}: {tensor_a}" + f"\n{ref_name}: {tensor_b}" + ) + else: + return True
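(To close out this hunk, a hedged usage sketch of `torch_assert_close` as defined above; the shapes and tolerances are illustrative only. Unlike `torch.testing.assert_close`, it only raises once the number of mismatched elements exceeds `max_mismatched_ratio` of the total, which is convenient for kernels with a handful of borderline rounding differences.)

```python
import torch

from tilelang.utils.tensor import torch_assert_close

ref = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
out = ref.clone()
out[0, 0] += 1.0  # a single outlier among ~1M elements

# Passes: 1 mismatch is well under the allowed budget of
# int(1024 * 1024 * 0.001) = 1048 mismatched elements.
torch_assert_close(out, ref, rtol=1e-2, atol=1e-3,
                   max_mismatched_ratio=0.001, verbose=True)
```

With `verbose=True` the helper also prints the mismatch count against the allowed budget, and on failure it reports the first mismatching index together with the greatest absolute and relative differences.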